diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-04-22 11:20:47 -0700 |
---|---|---|
committer | Patrick Wendell <pwendell@gmail.com> | 2014-04-22 11:20:47 -0700 |
commit | 26d35f3fd942761b0adecd1a720e1fa834db4de9 (patch) | |
tree | 16e57e2ff01e7cd2d7a1a3c1f3bf98c9cf98a082 /docs/mllib-decision-tree.md | |
parent | bf9d49b6d1f668b49795c2d380ab7d64ec0029da (diff) | |
download | spark-26d35f3fd942761b0adecd1a720e1fa834db4de9.tar.gz spark-26d35f3fd942761b0adecd1a720e1fa834db4de9.tar.bz2 spark-26d35f3fd942761b0adecd1a720e1fa834db4de9.zip |
[SPARK-1506][MLLIB] Documentation improvements for MLlib 1.0
Preview: http://54.82.240.23:4000/mllib-guide.html
Table of contents:
* Basics
* Data types
* Summary statistics
* Classification and regression
* linear support vector machine (SVM)
* logistic regression
* linear linear squares, Lasso, and ridge regression
* decision tree
* naive Bayes
* Collaborative Filtering
* alternating least squares (ALS)
* Clustering
* k-means
* Dimensionality reduction
* singular value decomposition (SVD)
* principal component analysis (PCA)
* Optimization
* stochastic gradient descent
* limited-memory BFGS (L-BFGS)
Author: Xiangrui Meng <meng@databricks.com>
Closes #422 from mengxr/mllib-doc and squashes the following commits:
944e3a9 [Xiangrui Meng] merge master
f9fda28 [Xiangrui Meng] minor
9474065 [Xiangrui Meng] add alpha to ALS examples
928e630 [Xiangrui Meng] initialization_mode -> initializationMode
5bbff49 [Xiangrui Meng] add imports to labeled point examples
c17440d [Xiangrui Meng] fix python nb example
28f40dc [Xiangrui Meng] remove localhost:4000
369a4d3 [Xiangrui Meng] Merge branch 'master' into mllib-doc
7dc95cc [Xiangrui Meng] update linear methods
053ad8a [Xiangrui Meng] add links to go back to the main page
abbbf7e [Xiangrui Meng] update ALS argument names
648283e [Xiangrui Meng] level down statistics
14e2287 [Xiangrui Meng] add sample libsvm data and use it in guide
8cd2441 [Xiangrui Meng] minor updates
186ab07 [Xiangrui Meng] update section names
6568d65 [Xiangrui Meng] update toc, level up lr and svm
162ee12 [Xiangrui Meng] rename section names
5c1e1b1 [Xiangrui Meng] minor
8aeaba1 [Xiangrui Meng] wrap long lines
6ce6a6f [Xiangrui Meng] add summary statistics to toc
5760045 [Xiangrui Meng] claim beta
cc604bf [Xiangrui Meng] remove classification and regression
92747b3 [Xiangrui Meng] make section titles consistent
e605dd6 [Xiangrui Meng] add LIBSVM loader
f639674 [Xiangrui Meng] add python section to migration guide
c82ffb4 [Xiangrui Meng] clean optimization
31660eb [Xiangrui Meng] update linear algebra and stat
0a40837 [Xiangrui Meng] first pass over linear methods
1fc8271 [Xiangrui Meng] update toc
906ed0a [Xiangrui Meng] add a python example to naive bayes
5f0a700 [Xiangrui Meng] update collaborative filtering
656d416 [Xiangrui Meng] update mllib-clustering
86e143a [Xiangrui Meng] remove data types section from main page
8d1a128 [Xiangrui Meng] move part of linear algebra to data types and add Java/Python examples
d1b5cbf [Xiangrui Meng] merge master
72e4804 [Xiangrui Meng] one pass over tree guide
64f8995 [Xiangrui Meng] move decision tree guide to a separate file
9fca001 [Xiangrui Meng] add first version of linear algebra guide
53c9552 [Xiangrui Meng] update dependencies
f316ec2 [Xiangrui Meng] add migration guide
f399f6c [Xiangrui Meng] move linear-algebra to dimensionality-reduction
182460f [Xiangrui Meng] add guide for naive Bayes
137fd1d [Xiangrui Meng] re-organize toc
a61e434 [Xiangrui Meng] update mllib's toc
Diffstat (limited to 'docs/mllib-decision-tree.md')
-rw-r--r-- | docs/mllib-decision-tree.md | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md new file mode 100644 index 0000000000..0693766990 --- /dev/null +++ b/docs/mllib-decision-tree.md @@ -0,0 +1,185 @@ +--- +layout: global +title: <a href="mllib-guide.html">MLlib</a> - Decision Tree +--- + +* Table of contents +{:toc} + +Decision trees and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical variables, extend to the multiclass classification setting, do not require +feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble +algorithms such as decision forest and boosting are among the top performers for classification and +regression tasks. + +## Basic algorithm + +The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature +space by choosing a single element from the *best split set* where each element of the set maximizes +the information gain at a tree node. In other words, the split chosen at each tree node is chosen +from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` is the information +gain when a split `$s$` is applied to a dataset `$D$`. + +### Node impurity and information gain + +The *node impurity* is a measure of the homogeneity of the labels at the node. The current +implementation provides two impurity measures for classification (Gini impurity and entropy) and one +impurity measure for regression (variance). + +<table class="table"> + <thead> + <tr><th>Impurity</th><th>Task</th><th>Formula</th><th>Description</th></tr> + </thead> + <tbody> + <tr> + <td>Gini impurity</td> + <td>Classification</td> + <td>$\sum_{i=1}^{M} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td> + </tr> + <tr> + <td>Entropy</td> + <td>Classification</td> + <td>$\sum_{i=1}^{M} -f_ilog(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td> + </tr> + <tr> + <td>Variance</td> + <td>Regression</td> + <td>$\frac{1}{n} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^n x_i$.</td> + </tr> + </tbody> +</table> + +The *information gain* is the difference in the parent node impurity and the weighted sum of the two +child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` into two +datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, respectively: + +`$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$` + +### Split candidates + +**Continuous features** + +For small datasets in single machine implementations, the split candidates for each continuous +feature are typically the unique values for the feature. Some implementations sort the feature +values and then use the ordered unique values as split candidates for faster tree calculations. + +Finding ordered unique feature values is computationally intensive for large distributed +datasets. One can get an approximate set of split candidates by performing a quantile calculation +over a sampled fraction of the data. The ordered splits create "bins" and the maximum number of such +bins can be specified using the `maxBins` parameters. + +Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario +since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of +bins if the condition is not satisfied. + +**Categorical features** + +For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for +binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the +categorical feature values by the proportion of labels falling in one of the two classes (see +Section 9.2.4 in +[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +details). For example, for a binary classification problem with one categorical feature with three +categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical +features are orded as A followed by C followed B or A, B, C. The two split candidates are A \| C, B +and A , B \| C where \| denotes the split. + +### Stopping rule + +The recursive tree construction is stopped at a node when one of the two conditions is met: + +1. The node depth is equal to the `maxDepth` training parammeter +2. No split candidate leads to an information gain at the node. + +### Practical limitations + +1. The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* + in memory for aggregating histograms over partitions. The current implementation might not scale + to very deep trees since the memory requirement grows exponentially with tree depth. +2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for + sparse input. +3. Python is not supported in this release. + +We are planning to solve these problems in the near future. Please drop us a line if you encounter +any issues. + +## Examples + +### Classification + +The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then +perform classification using a decision tree using Gini impurity as an impurity measure and a +maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy. + +<div class="codetabs"> +<div data-lang="scala"> +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Gini + +// Load and parse the data file +val data = sc.textFile("mllib/data/sample_tree_data.csv") +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + LabeledPoint(parts(0), Vectors.dense(parts.tail)) +} + +// Run training algorithm to build the model +val maxDepth = 5 +val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) + +// Evaluate model on training examples and compute training error +val labelAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) +} +val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count +println("Training Error = " + trainErr) +{% endhighlight %} +</div> +</div> + +### Regression + +The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then +perform regression using a decision tree using variance as an impurity measure and a maximum tree +depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate +[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). + +<div class="codetabs"> +<div data-lang="scala"> +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Variance + +// Load and parse the data file +val data = sc.textFile("mllib/data/sample_tree_data.csv") +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + LabeledPoint(parts(0), Vectors.dense(parts.tail)) +} + +// Run training algorithm to build the model +val maxDepth = 5 +val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) + +// Evaluate model on training examples and compute training error +val valuesAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) +} +val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count +println("training Mean Squared Error = " + MSE) +{% endhighlight %} +</div> +</div> |