diff options
Diffstat (limited to 'docs/mllib-decision-tree.md')
-rw-r--r-- | docs/mllib-decision-tree.md | 185 |
1 files changed, 185 insertions, 0 deletions
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md new file mode 100644 index 0000000000..0693766990 --- /dev/null +++ b/docs/mllib-decision-tree.md @@ -0,0 +1,185 @@ +--- +layout: global +title: <a href="mllib-guide.html">MLlib</a> - Decision Tree +--- + +* Table of contents +{:toc} + +Decision trees and their ensembles are popular methods for the machine learning tasks of +classification and regression. Decision trees are widely used since they are easy to interpret, +handle categorical variables, extend to the multiclass classification setting, do not require +feature scaling and are able to capture nonlinearities and feature interactions. Tree ensemble +algorithms such as decision forest and boosting are among the top performers for classification and +regression tasks. + +## Basic algorithm + +The decision tree is a greedy algorithm that performs a recursive binary partitioning of the feature +space by choosing a single element from the *best split set* where each element of the set maximizes +the information gain at a tree node. In other words, the split chosen at each tree node is chosen +from the set `$\underset{s}{\operatorname{argmax}} IG(D,s)$` where `$IG(D,s)$` is the information +gain when a split `$s$` is applied to a dataset `$D$`. + +### Node impurity and information gain + +The *node impurity* is a measure of the homogeneity of the labels at the node. The current +implementation provides two impurity measures for classification (Gini impurity and entropy) and one +impurity measure for regression (variance). + +<table class="table"> + <thead> + <tr><th>Impurity</th><th>Task</th><th>Formula</th><th>Description</th></tr> + </thead> + <tbody> + <tr> + <td>Gini impurity</td> + <td>Classification</td> + <td>$\sum_{i=1}^{M} f_i(1-f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td> + </tr> + <tr> + <td>Entropy</td> + <td>Classification</td> + <td>$\sum_{i=1}^{M} -f_ilog(f_i)$</td><td>$f_i$ is the frequency of label $i$ at a node and $M$ is the number of unique labels.</td> + </tr> + <tr> + <td>Variance</td> + <td>Regression</td> + <td>$\frac{1}{n} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^n x_i$.</td> + </tr> + </tbody> +</table> + +The *information gain* is the difference in the parent node impurity and the weighted sum of the two +child node impurities. Assuming that a split $s$ partitions the dataset `$D$` of size `$N$` into two +datasets `$D_{left}$` and `$D_{right}$` of sizes `$N_{left}$` and `$N_{right}$`, respectively: + +`$IG(D,s) = Impurity(D) - \frac{N_{left}}{N} Impurity(D_{left}) - \frac{N_{right}}{N} Impurity(D_{right})$` + +### Split candidates + +**Continuous features** + +For small datasets in single machine implementations, the split candidates for each continuous +feature are typically the unique values for the feature. Some implementations sort the feature +values and then use the ordered unique values as split candidates for faster tree calculations. + +Finding ordered unique feature values is computationally intensive for large distributed +datasets. One can get an approximate set of split candidates by performing a quantile calculation +over a sampled fraction of the data. The ordered splits create "bins" and the maximum number of such +bins can be specified using the `maxBins` parameters. + +Note that the number of bins cannot be greater than the number of instances `$N$` (a rare scenario +since the default `maxBins` value is 100). The tree algorithm automatically reduces the number of +bins if the condition is not satisfied. + +**Categorical features** + +For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for +binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the +categorical feature values by the proportion of labels falling in one of the two classes (see +Section 9.2.4 in +[Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for +details). For example, for a binary classification problem with one categorical feature with three +categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical +features are orded as A followed by C followed B or A, B, C. The two split candidates are A \| C, B +and A , B \| C where \| denotes the split. + +### Stopping rule + +The recursive tree construction is stopped at a node when one of the two conditions is met: + +1. The node depth is equal to the `maxDepth` training parammeter +2. No split candidate leads to an information gain at the node. + +### Practical limitations + +1. The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* + in memory for aggregating histograms over partitions. The current implementation might not scale + to very deep trees since the memory requirement grows exponentially with tree depth. +2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for + sparse input. +3. Python is not supported in this release. + +We are planning to solve these problems in the near future. Please drop us a line if you encounter +any issues. + +## Examples + +### Classification + +The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then +perform classification using a decision tree using Gini impurity as an impurity measure and a +maximum tree depth of 5. The training error is calculated to measure the algorithm accuracy. + +<div class="codetabs"> +<div data-lang="scala"> +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Gini + +// Load and parse the data file +val data = sc.textFile("mllib/data/sample_tree_data.csv") +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + LabeledPoint(parts(0), Vectors.dense(parts.tail)) +} + +// Run training algorithm to build the model +val maxDepth = 5 +val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) + +// Evaluate model on training examples and compute training error +val labelAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) +} +val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count +println("Training Error = " + trainErr) +{% endhighlight %} +</div> +</div> + +### Regression + +The example below demonstrates how to load a CSV file, parse it as an RDD of `LabeledPoint` and then +perform regression using a decision tree using variance as an impurity measure and a maximum tree +depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate +[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit). + +<div class="codetabs"> +<div data-lang="scala"> +{% highlight scala %} +import org.apache.spark.SparkContext +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Variance + +// Load and parse the data file +val data = sc.textFile("mllib/data/sample_tree_data.csv") +val parsedData = data.map { line => + val parts = line.split(',').map(_.toDouble) + LabeledPoint(parts(0), Vectors.dense(parts.tail)) +} + +// Run training algorithm to build the model +val maxDepth = 5 +val model = DecisionTree.train(parsedData, Regression, Variance, maxDepth) + +// Evaluate model on training examples and compute training error +val valuesAndPreds = parsedData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) +} +val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count +println("training Mean Squared Error = " + MSE) +{% endhighlight %} +</div> +</div> |