authorJoseph K. Bradley <joseph.kurata.bradley@gmail.com>2014-07-31 20:51:48 -0700
committerXiangrui Meng <meng@databricks.com>2014-07-31 20:51:48 -0700
commitb124de584a45b7ebde9fbe10128db429c56aeaee (patch)
tree6f8b70447a4d825cfe1ad71870741f8f8d77ba3d
parentd8430148ee1f6ba02569db0538eeae473a32c78e (diff)
[SPARK-2756] [mllib] Decision tree bug fixes
(1) Inconsistent aggregate (agg) indexing for unordered features.
(2) Fixed gain calculations for edge cases.
(3) Off-by-one error in choosing thresholds for continuous features for small datasets.
(4) (not a bug) Changed meaning of tree depth by 1 to fit scikit-learn and rpart. (Depth 1 used to mean 1 leaf node; depth 0 now means 1 leaf node.)

Other updates, to help with tests:
* Updated DecisionTreeRunner to print more info.
* Added utility functions to DecisionTreeModel: toString, depth, numNodes.
* Improved internal DecisionTree documentation.

Bug fix details:

(1) Indexing was inconsistent for aggregate calculations for unordered features (in multiclass classification with categorical features, where the features had few enough values that they could be treated as unordered, i.e., isSpaceSufficientForAllCategoricalSplits=true).
* updateBinForUnorderedFeature indexed agg as (node, feature, featureValue, binIndex), where
** featureValue was taken from arr (so it was a feature value), and
** binIndex was in [0, ..., 2^(maxFeatureValue-1)-1).
* The rest of the code indexed agg as (node, feature, binIndex, label).
* Corrected this bug by changing updateBinForUnorderedFeature to use the second indexing pattern.

Unit tests in DecisionTreeSuite:
* Updated a few tests to train a model and check its training accuracy, which catches the indexing bug in updateBinForUnorderedFeature() discussed above.
* Added a new test ("stump with categorical variables for multiclass classification, with just enough bins") to test bin extremes.

(2) Bug fix: calculateGainForSplit (for classification).
* It used to return dummy prediction values when either the left or right child had 0 weight. These were incorrect for multiclass classification. This has been corrected.
* Updated the impurities to allow count = 0. This is related to the above fix for calculateGainForSplit (for classification).
* Small updates to documentation and coding style.

(3) Bug fix: Off-by-one error when finding thresholds for splits on continuous features.
* Exhibited the bug in a new test in DecisionTreeSuite: "stump with 1 continuous variable for binary classification, to check off-by-1 error".
* Description: When finding thresholds for possible splits on continuous features in DecisionTree.findSplitsBins, the thresholds were set to individual training examples' feature values.
* Fix: Each threshold is now the average of 2 consecutive (sorted) examples' feature values. E.g., where the old code set the threshold to example i's feature value, the new code sets it to the average of the feature values of examples i and i+1.
* Note: In 4 DecisionTreeSuite tests with all labels identical, the threshold check was removed since the threshold is somewhat arbitrary.

CC: mengxr manishamde Please let me know if I missed something!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1673 from jkbradley/decisiontree-bugfix and squashes the following commits:

2b20c61 [Joseph K. Bradley] Small doc and style updates
dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals
8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type.
376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart.
* In code, replaced usages of maxDepth <-- maxDepth + 1.
* In params, replaced settings of maxDepth <-- maxDepth - 1.
59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo == Classification.
* Updates based on comments:
** DecisionTreeRunner: made dataFormat arg default to libsvm.
** Small cleanups.
** tree.Node: made recursive helper methods private, and renamed them.
52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now.
8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features.
2283df8 [Joseph K. Bradley] 2 bug fixes.
73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix
5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit.
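For readers following the indexing fix in (1): a minimal standalone sketch of the corrected flat agg layout, (node, feature, bin, label) with label in the least significant position. All names and the main method are illustrative, not from the patch itself.

// Standalone sketch of the (node, feature, bin, label) flat layout used by the
// fixed aggregation code; AggIndexSketch and aggIndex are illustrative names.
object AggIndexSketch {
  def aggIndex(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int,
               numFeatures: Int, numBins: Int, numClasses: Int): Int = {
    // label is the least significant "digit", then bin, then feature, then node.
    numClasses * numBins * numFeatures * nodeIndex +
      numClasses * numBins * featureIndex +
      numClasses * binIndex +
      label
  }

  def main(args: Array[String]): Unit = {
    val (numFeatures, numBins, numClasses) = (2, 3, 2)
    val agg = new Array[Double](numClasses * numBins * numFeatures * 1) // 1 node
    // Increment the count for (node 0, feature 1, bin 2, label 1):
    agg(aggIndex(0, 1, 2, 1, numFeatures, numBins, numClasses)) += 1
    println(agg.mkString(",")) // the 1 lands at flat index 11 of 12
  }
}

The pre-fix updateBinForUnorderedFeature used a different digit order from the rest of the code, so increments landed in the wrong slots; the patch makes it use this shared layout.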
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala92
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala408
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala7
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala31
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala56
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala115
10 files changed, 538 insertions, 193 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 6db9bf3cf5..cf3d2cca81 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -21,7 +21,6 @@ import scopt.OptionParser
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
@@ -36,6 +35,9 @@ import org.apache.spark.rdd.RDD
* ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ *
+ * Note: This script treats all features as real-valued (not categorical).
+ * To include categorical features, modify categoricalFeaturesInfo.
*/
object DecisionTreeRunner {
@@ -48,11 +50,12 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
+ dataFormat: String = "libsvm",
algo: Algo = Classification,
- numClassesForClassification: Int = 2,
- maxDepth: Int = 5,
+ maxDepth: Int = 4,
impurity: ImpurityType = Gini,
- maxBins: Int = 100)
+ maxBins: Int = 100,
+ fracTest: Double = 0.2)
def main(args: Array[String]) {
val defaultParams = Params()
@@ -69,25 +72,31 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
- opt[Int]("numClassesForClassification")
- .text(s"number of classes for classification, "
- + s"default: ${defaultParams.numClassesForClassification}")
- .action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
arg[String]("<input>")
.text("input paths to labeled examples in dense format (label,f0 f1 f2 ...)")
.required()
.action((x, c) => c.copy(input = x))
checkConfig { params =>
- if (params.algo == Classification &&
- (params.impurity == Gini || params.impurity == Entropy)) {
- success
- } else if (params.algo == Regression && params.impurity == Variance) {
- success
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
} else {
- failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ if (params.algo == Classification &&
+ (params.impurity == Gini || params.impurity == Entropy)) {
+ success
+ } else if (params.algo == Regression && params.impurity == Variance) {
+ success
+ } else {
+ failure(s"Algo ${params.algo} is not compatible with impurity ${params.impurity}.")
+ }
}
}
}
@@ -100,16 +109,57 @@ object DecisionTreeRunner {
}
def run(params: Params) {
+
val conf = new SparkConf().setAppName("DecisionTreeRunner")
val sc = new SparkContext(conf)
// Load training data and cache it.
- val examples = MLUtils.loadLabeledPoints(sc, params.input).cache()
+ val origExamples = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache()
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache()
+ }
+ // For classification, re-index classes if needed.
+ val (examples, numClasses) = params.algo match {
+ case Classification => {
+ // classCounts: class --> # examples in class
+ val classCounts = origExamples.map(_.label).countByValue()
+ val sortedClasses = classCounts.keys.toList.sorted
+ val numClasses = classCounts.size
+ // classIndexMap: class --> index in 0,...,numClasses-1
+ val classIndexMap = {
+ if (classCounts.keySet != Set(0.0, 1.0)) {
+ sortedClasses.zipWithIndex.toMap
+ } else {
+ Map[Double, Int]()
+ }
+ }
+ val examples = {
+ if (classIndexMap.isEmpty) {
+ origExamples
+ } else {
+ origExamples.map(lp => LabeledPoint(classIndexMap(lp.label), lp.features))
+ }
+ }
+ val numExamples = examples.count()
+ println(s"numClasses = $numClasses.")
+ println(s"Per-class example fractions, counts:")
+ println(s"Class\tFrac\tCount")
+ sortedClasses.foreach { c =>
+ val frac = classCounts(c) / numExamples.toDouble
+ println(s"$c\t$frac\t${classCounts(c)}")
+ }
+ (examples, numClasses)
+ }
+ case Regression =>
+ (origExamples, 0)
+ case _ =>
+ throw new IllegalArgumentException("Algo ${params.algo} not supported.")
+ }
- val splits = examples.randomSplit(Array(0.8, 0.2))
+ // Split into training, test.
+ val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
val training = splits(0).cache()
val test = splits(1).cache()
-
val numTraining = training.count()
val numTest = test.count()
@@ -129,17 +179,19 @@ object DecisionTreeRunner {
impurity = impurityCalculator,
maxDepth = params.maxDepth,
maxBins = params.maxBins,
- numClassesForClassification = params.numClassesForClassification)
+ numClassesForClassification = numClasses)
val model = DecisionTree.train(training, strategy)
+ println(model)
+
if (params.algo == Classification) {
val accuracy = accuracyScore(model, test)
- println(s"Test accuracy = $accuracy.")
+ println(s"Test accuracy = $accuracy")
}
if (params.algo == Regression) {
val mse = meanSquaredError(model, test)
- println(s"Test mean squared error = $mse.")
+ println(s"Test mean squared error = $mse")
}
sc.stop()
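To illustrate the class re-indexing added to the runner above, a minimal sketch on a plain Scala collection (no SparkContext; countByValue is emulated with groupBy, and the names are illustrative):

// Sketch of the runner's label re-indexing on a plain collection (no Spark):
// arbitrary class labels {3.0, 5.0, 7.0} are mapped to indices {0, 1, 2}.
object ClassIndexSketch {
  def main(args: Array[String]): Unit = {
    val labels = Seq(5.0, 3.0, 7.0, 3.0, 5.0)
    val classCounts: Map[Double, Int] =
      labels.groupBy(identity).map { case (label, group) => label -> group.size }
    val sortedClasses = classCounts.keys.toList.sorted
    // Re-index unless the labels are already the standard binary {0.0, 1.0}.
    val classIndexMap: Map[Double, Int] =
      if (classCounts.keySet != Set(0.0, 1.0)) sortedClasses.zipWithIndex.toMap
      else Map.empty
    val reindexed = labels.map(l => classIndexMap.getOrElse(l, l.toInt).toDouble)
    println(reindexed) // List(1.0, 0.0, 2.0, 0.0, 1.0)
  }
}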
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index ad32e3f456..7d123dd6ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.random.XORShiftRandom
/**
* :: Experimental ::
- * A class that implements a decision tree algorithm for classification and regression. It
- * supports both continuous and categorical features.
+ * A class which implements a decision tree learning algorithm for classification and regression.
+ * It supports both continuous and categorical features.
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
@@ -42,8 +42,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
/**
* Method to train a decision tree model over an RDD
- * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * @return a DecisionTreeModel that can be used for prediction
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
@@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
- val maxNumNodes = math.pow(2, maxDepth).toInt - 1
+ val maxNumNodes = math.pow(2, maxDepth + 1).toInt - 1
// Initialize an array to hold filters applied to points for each node.
val filters = new Array[List[Filter]](maxNumNodes)
// The filter at the top node is an empty list.
@@ -100,7 +100,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
var level = 0
var break = false
- while (level < maxDepth && !break) {
+ while (level <= maxDepth && !break) {
logDebug("#####################################")
logDebug("level = " + level)
@@ -152,7 +152,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
val split = nodeSplitStats._1
val stats = nodeSplitStats._2
val nodeIndex = math.pow(2, level).toInt - 1 + index
- val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
+ val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
logDebug("Node = " + node)
nodes(nodeIndex) = node
@@ -173,7 +173,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i
- if (level < maxDepth - 1) {
+ if (level < maxDepth) {
val impurity = if (i == 0) {
nodeSplitStats._2.leftImpurity
} else {
@@ -197,17 +197,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
object DecisionTree extends Serializable with Logging {
/**
- * Method to train a decision tree model where the instances are represented as an RDD of
- * (label, features) pairs. The method supports binary classification and regression. For the
- * binary classification, the label for each instance should either be 0 or 1 to denote the two
- * classes. The parameters for the algorithm are specified using the strategy parameter.
+ * Method to train a decision tree model.
+ * The method supports binary and multiclass classification and regression.
*
- * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * for DecisionTree
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
* @param strategy The configuration parameters for the tree algorithm which specify the type
* of algorithm (classification, regression, etc.), feature type (continuous,
* categorical), depth of the tree, quantile calculation strategy, etc.
- * @return a DecisionTreeModel that can be used for prediction
+ * @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
new DecisionTree(strategy).train(input)
@@ -219,12 +218,14 @@ object DecisionTree extends Serializable with Logging {
* binary classification, the label for each instance should either be 0 or 1 to denote the two
* classes.
*
- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
- * training data
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
- * @param maxDepth maxDepth maximum depth of the tree
- * @return a DecisionTreeModel that can be used for prediction
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
@@ -241,13 +242,15 @@ object DecisionTree extends Serializable with Logging {
* binary classification, the label for each instance should either be 0 or 1 to denote the two
* classes.
*
- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
- * training data
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
* @param algo algorithm, classification or regression
* @param impurity impurity criterion used for information gain calculation
- * @param maxDepth maxDepth maximum depth of the tree
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value of 2.
- * @return a DecisionTreeModel that can be used for prediction
+ * @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
@@ -266,11 +269,13 @@ object DecisionTree extends Serializable with Logging {
* 1 to denote the two classes. The method also supports categorical features inputs where the
* number of categories can specified using the categoricalFeaturesInfo option.
*
- * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
- * training data for DecisionTree
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
- * @param maxDepth maximum depth of the tree
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value of 2.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
@@ -279,7 +284,7 @@ object DecisionTree extends Serializable with Logging {
* an entry (n -> k) implies the feature n is categorical with k
* categories 0, 1, 2, ... , k-1. It's important to note that
* features are zero-indexed.
- * @return a DecisionTreeModel that can be used for prediction
+ * @return DecisionTreeModel that can be used for prediction
*/
def train(
input: RDD[LabeledPoint],
@@ -301,11 +306,10 @@ object DecisionTree extends Serializable with Logging {
* Returns an array of optimal splits for all nodes at a given level. Splits the task into
* multiple groups if the level-wise training task could lead to memory overflow.
*
- * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * for DecisionTree
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for construction the DecisionTree
+ * parameters for constructing the DecisionTree
* @param level Level of the tree
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
@@ -348,11 +352,10 @@ object DecisionTree extends Serializable with Logging {
/**
* Returns an array of optimal splits for a group of nodes at a given level
*
- * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * for DecisionTree
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for construction the DecisionTree
+ * parameters for constructing the DecisionTree
* @param level Level of the tree
* @param filters Filters for all nodes at a given level
* @param splits possible splits for all features
@@ -373,7 +376,7 @@ object DecisionTree extends Serializable with Logging {
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
/*
- * The high-level description for the best split optimizations are noted here.
+ * The high-level descriptions of the best split optimizations are noted here.
*
* *Level-wise training*
* We perform bin calculations for all nodes at the given level to avoid making multiple
@@ -396,18 +399,27 @@ object DecisionTree extends Serializable with Logging {
* drastically reduce the communication overhead.
*/
- // common calculations for multiple nested methods
+ // Common calculations for multiple nested methods:
+
+ // numNodes: Number of nodes in this (level of tree, group),
+ // where nodes at deeper (larger) levels may be divided into groups.
val numNodes = math.pow(2, level).toInt / numGroups
logDebug("numNodes = " + numNodes)
+
// Find the number of features by looking at the first sample.
val numFeatures = input.first().features.size
logDebug("numFeatures = " + numFeatures)
+
+ // numBins: Number of bins = 1 + number of possible splits
val numBins = bins(0).length
logDebug("numBins = " + numBins)
+
val numClasses = strategy.numClassesForClassification
logDebug("numClasses = " + numClasses)
+
val isMulticlassClassification = strategy.isMulticlassClassification
logDebug("isMulticlassClassification = " + isMulticlassClassification)
+
val isMulticlassClassificationWithCategoricalFeatures
= strategy.isMulticlassWithCategoricalFeatures
logDebug("isMultiClassWithCategoricalFeatures = " +
@@ -465,10 +477,13 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Find bin for one feature.
+ * Find bin for one (labeledPoint, feature).
*/
- def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
- isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
+ def findBin(
+ featureIndex: Int,
+ labeledPoint: LabeledPoint,
+ isFeatureContinuous: Boolean,
+ isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)
@@ -535,7 +550,9 @@ object DecisionTree extends Serializable with Logging {
} else {
// Perform sequential search to find bin for categorical features.
val binIndex = {
- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+ val isUnorderedFeature =
+ isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+ if (isUnorderedFeature) {
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
} else {
sequentialBinSearchForOrderedCategoricalFeatureInClassification()
@@ -555,6 +572,14 @@ object DecisionTree extends Serializable with Logging {
* where b_ij is an integer between 0 and numBins - 1 for regressions and binary
* classification and the categorical feature value in multiclass classification.
* Invalid sample is denoted by noting bin for feature 1 as -1.
+ *
+ * For unordered features, the "bin index" returned is actually the feature value (category).
+ *
+ * @return Array of size 1 + numFeatures * numNodes, where
+ * arr(0) = label for labeledPoint, and
+ * arr(1 + numFeatures * nodeIndex + featureIndex) =
+ * bin index for this labeledPoint
+ * (or InvalidBinIndex if labeledPoint is not handled by this node)
*/
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
// Calculate bin index and label per feature per node.
@@ -598,9 +623,21 @@ object DecisionTree extends Serializable with Logging {
// Find feature bins for all nodes at a level.
val binMappedRDD = input.map(x => findBinsForLevel(x))
- def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
- label: Double, featureIndex: Int) = {
-
+ /**
+ * Increment aggregate in location for (node, feature, bin, label).
+ *
+ * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
+ * Array of size 1 + (numFeatures * numNodes).
+ * @param agg Array storing aggregate calculation, of size:
+ * numClasses * numBins * numFeatures * numNodes.
+ * Indexed by (node, feature, bin, label) where label is the least significant bit.
+ */
+ def updateBinForOrderedFeature(
+ arr: Array[Double],
+ agg: Array[Double],
+ nodeIndex: Int,
+ label: Double,
+ featureIndex: Int): Unit = {
// Find the bin index for this feature.
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
@@ -612,44 +649,58 @@ object DecisionTree extends Serializable with Logging {
agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
}
- def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
- label: Double, agg: Array[Double], rightChildShift: Int) = {
+ /**
+ * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label),
+ * where [bins] ranges over all bins.
+ * Updates left or right side of aggregate depending on split.
+ *
+ * @param arr arr(0) = label.
+ * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category)
+ * @param agg Indexed by (left/right, node, feature, bin, label)
+ * where label is the least significant bit.
+ * The left/right specifier is a 0/1 index indicating left/right child info.
+ * @param rightChildShift Offset for right side of agg.
+ */
+ def updateBinForUnorderedFeature(
+ nodeIndex: Int,
+ featureIndex: Int,
+ arr: Array[Double],
+ label: Double,
+ agg: Array[Double],
+ rightChildShift: Int): Unit = {
// Find the bin index for this feature.
- val arrShift = 1 + numFeatures * nodeIndex
- val arrIndex = arrShift + featureIndex
+ val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
+ val featureValue = arr(arrIndex).toInt
// Update the left or right count for one bin.
- val aggShift = numClasses * numBins * numFeatures * nodeIndex
- val aggIndex
- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+ val aggShift =
+ numClasses * numBins * numFeatures * nodeIndex +
+ numClasses * numBins * featureIndex +
+ label.toInt
// Find all matching bins and increment their values
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
- val labelInt = label.toInt
- if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) {
- agg(aggIndex + binIndex)
- = agg(aggIndex + binIndex) + 1
+ val aggIndex = aggShift + binIndex * numClasses
+ if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
+ agg(aggIndex) += 1
} else {
- agg(rightChildShift + aggIndex + binIndex)
- = agg(rightChildShift + aggIndex + binIndex) + 1
+ agg(rightChildShift + aggIndex) += 1
}
binIndex += 1
}
}
/**
- * Performs a sequential aggregation over a partition for classification. For l nodes,
- * k features, either the left count or the right count of one of the p bins is
- * incremented based upon whether the feature is classified as 0 or 1.
+ * Helper for binSeqOp.
*
- * @param agg Array[Double] storing aggregate calculation of size
- * numClasses * numSplits * numFeatures*numNodes for classification
- * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
- * @return Array[Double] storing aggregate calculation of size
- * 2 * numSplits * numFeatures * numNodes for classification
+ * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
+ * Array of size 1 + (numFeatures * numNodes).
+ * @param agg Array storing aggregate calculation, of size:
+ * numClasses * numBins * numFeatures * numNodes.
+ * Indexed by (node, feature, bin, label) where label is the least significant bit.
*/
- def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+ def binaryOrNotCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -671,17 +722,21 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Performs a sequential aggregation over a partition for classification. For l nodes,
- * k features, either the left count or the right count of one of the p bins is
- * incremented based upon whether the feature is classified as 0 or 1.
+ * Helper for binSeqOp.
*
- * @param agg Array[Double] storing aggregate calculation of size
- * numClasses * numSplits * numFeatures*numNodes for classification
- * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
- * @return Array[Double] storing aggregate calculation of size
- * 2 * numClasses * numSplits * numFeatures * numNodes for classification
+ * @param arr Bin mapping from findBinsForLevel. arr(0) stores the class label.
+ * Array of size 1 + (numFeatures * numNodes).
+ * For ordered features,
+ * arr(1 + featureIndex + nodeIndex * numFeatures) = bin index.
+ * For unordered features,
+ * arr(1 + featureIndex + nodeIndex * numFeatures) = feature value (category).
+ * @param agg Array storing aggregate calculation.
+ * For ordered features, this is of size:
+ * numClasses * numBins * numFeatures * numNodes.
+ * For unordered features, this is of size:
+ * 2 * numClasses * numBins * numFeatures * numNodes.
*/
- def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+ def multiclassWithCategoricalBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -717,16 +772,17 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Performs a sequential aggregation over a partition for regression. For l nodes, k features,
+ * Performs a sequential aggregation over a partition for regression.
+ * For l nodes, k features,
* the count, sum, sum of squares of one of the p bins is incremented.
*
- * @param agg Array[Double] storing aggregate calculation of size
- * 3 * numSplits * numFeatures * numNodes for classification
- * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
- * @return Array[Double] storing aggregate calculation of size
- * 3 * numSplits * numFeatures * numNodes for regression
+ * @param agg Array storing aggregate calculation, updated by this function.
+ * Size: 3 * numBins * numFeatures * numNodes
+ * @param arr Bin mapping from findBinsForLevel.
+ * Array of size 1 + (numFeatures * numNodes).
+ * @return agg
*/
- def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]): Unit = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -757,14 +813,30 @@ object DecisionTree extends Serializable with Logging {
/**
* Performs a sequential aggregation over a partition.
+ * For l nodes, k features,
+ * For classification:
+ * Either the left count or the right count of one of the bins is
+ * incremented based upon whether the feature is classified as 0 or 1.
+ * For regression:
+ * The count, sum, sum of squares of one of the bins is incremented.
+ *
+ * @param agg Array storing aggregate calculation, updated by this function.
+ * Size for classification:
+ * numClasses * numBins * numFeatures * numNodes for ordered features, or
+ * 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
+ * Size for regression:
+ * 3 * numBins * numFeatures * numNodes.
+ * @param arr Bin mapping from findBinsForLevel.
+ * Array of size 1 + (numFeatures * numNodes).
+ * @return agg
*/
def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
strategy.algo match {
case Classification =>
if(isMulticlassClassificationWithCategoricalFeatures) {
- unorderedClassificationBinSeqOp(arr, agg)
+ multiclassWithCategoricalBinSeqOp(arr, agg)
} else {
- orderedClassificationBinSeqOp(arr, agg)
+ binaryOrNotCategoricalBinSeqOp(arr, agg)
}
case Regression => regressionBinSeqOp(arr, agg)
}
@@ -815,20 +887,10 @@ object DecisionTree extends Serializable with Logging {
topImpurity: Double): InformationGainStats = {
strategy.algo match {
case Classification =>
- var classIndex = 0
- val leftCounts: Array[Double] = new Array[Double](numClasses)
- val rightCounts: Array[Double] = new Array[Double](numClasses)
- var leftTotalCount = 0.0
- var rightTotalCount = 0.0
- while (classIndex < numClasses) {
- val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
- val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
- leftCounts(classIndex) = leftClassCount
- leftTotalCount += leftClassCount
- rightCounts(classIndex) = rightClassCount
- rightTotalCount += rightClassCount
- classIndex += 1
- }
+ val leftCounts: Array[Double] = leftNodeAgg(featureIndex)(splitIndex)
+ val rightCounts: Array[Double] = rightNodeAgg(featureIndex)(splitIndex)
+ val leftTotalCount = leftCounts.sum
+ val rightTotalCount = rightCounts.sum
val impurity = {
if (level > 0) {
@@ -845,33 +907,17 @@ object DecisionTree extends Serializable with Logging {
}
}
- if (leftTotalCount == 0) {
- return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
- }
- if (rightTotalCount == 0) {
- return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
- }
-
- val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
- val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
-
- val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
- val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
-
- val gain = {
- if (level > 0) {
- impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- } else {
- impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
- }
- }
-
val totalCount = leftTotalCount + rightTotalCount
+ if (totalCount == 0) {
+ // Return arbitrary prediction.
+ return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+ }
// Sum of count for each label
- val leftRightCounts: Array[Double]
- = leftCounts.zip(rightCounts)
- .map{case (leftCount, rightCount) => leftCount + rightCount}
+ val leftRightCounts: Array[Double] =
+ leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
+ leftCount + rightCount
+ }
def indexOfLargestArrayElement(array: Array[Double]): Int = {
val result = array.foldLeft(-1, Double.MinValue, 0) {
@@ -885,6 +931,22 @@ object DecisionTree extends Serializable with Logging {
val predict = indexOfLargestArrayElement(leftRightCounts)
val prob = leftRightCounts(predict) / totalCount
+ val leftImpurity = if (leftTotalCount == 0) {
+ topImpurity
+ } else {
+ strategy.impurity.calculate(leftCounts, leftTotalCount)
+ }
+ val rightImpurity = if (rightTotalCount == 0) {
+ topImpurity
+ } else {
+ strategy.impurity.calculate(rightCounts, rightTotalCount)
+ }
+
+ val leftWeight = leftTotalCount / totalCount
+ val rightWeight = rightTotalCount / totalCount
+
+ val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
+
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
case Regression =>
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
@@ -937,10 +999,18 @@ object DecisionTree extends Serializable with Logging {
/**
* Extracts left and right split aggregates.
- * @param binData Array[Double] of size 2*numFeatures*numSplits
- * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
- * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature,
- * (numBins - 1), numClasses)
+ * @param binData Aggregate array slice from getBinDataForNode.
+ * For classification:
+ * For unordered features, this is leftChildData ++ rightChildData,
+ * each of which is indexed by (feature, split/bin, class),
+ * with class being the least significant bit.
+ * For ordered features, this is of size numClasses * numBins * numFeatures.
+ * For regression:
+ * This is of size 2 * numFeatures * numBins.
+ * @return (leftNodeAgg, rightNodeAgg) pair of arrays.
+ * For classification, each array is of size (numFeatures, (numBins - 1), numClasses).
+ * For regression, each array is of size (numFeatures, (numBins - 1), 3).
+ *
*/
def extractLeftRightNodeAggregates(
binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
@@ -983,6 +1053,11 @@ object DecisionTree extends Serializable with Logging {
}
}
+ /**
+ * Reshape binData for this feature.
+ * Indexes binData as (feature, split, class) with class as the least significant bit.
+ * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value
+ */
def findAggForUnorderedFeatureClassification(
leftNodeAgg: Array[Array[Array[Double]]],
rightNodeAgg: Array[Array[Array[Double]]],
@@ -1107,7 +1182,7 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
- * @param binData Array[Double] of size 2 * numSplits * numFeatures
+ * @param binData Bin data slice for this node, given by getBinDataForNode.
* @param nodeImpurity impurity of the top node
* @return tuple of split and information gain
*/
@@ -1133,7 +1208,7 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
- val maxSplitIndex : Double = {
+ val maxSplitIndex: Double = {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
numBins - 1
@@ -1162,8 +1237,8 @@ object DecisionTree extends Serializable with Logging {
(bestFeatureIndex, bestSplitIndex, bestGainStats)
}
+ logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex))
logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex))
- logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex))
(splits(bestFeatureIndex)(bestSplitIndex), gainStats)
}
@@ -1214,8 +1289,17 @@ object DecisionTree extends Serializable with Logging {
bestSplits
}
- private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
- isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
+ /**
+ * Get the number of values to be stored per node in the bin aggregates.
+ *
+ * @param numBins Number of bins = 1 + number of possible splits.
+ */
+ private def getElementsPerNode(
+ numFeatures: Int,
+ numBins: Int,
+ numClasses: Int,
+ isMulticlassClassificationWithCategoricalFeatures: Boolean,
+ algo: Algo): Int = {
algo match {
case Classification =>
if (isMulticlassClassificationWithCategoricalFeatures) {
@@ -1228,18 +1312,40 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Returns split and bins for decision tree calculation.
- * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
- * for DecisionTree
+ * Returns splits and bins for decision tree calculation.
+ * Continuous and categorical features are handled differently.
+ *
+ * Continuous features:
+ * For each feature, there are numBins - 1 possible splits representing the possible binary
+ * decisions at each node in the tree.
+ *
+ * Categorical features:
+ * For each feature, there is 1 bin per split.
+ * Splits and bins are handled in 2 ways:
+ * (a) For multiclass classification with a low-arity feature
+ * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
+ * the feature is split based on subsets of categories.
+ * There are 2^(maxFeatureValue - 1) - 1 splits.
+ * (b) For regression and binary classification,
+ * and for multiclass classification with a high-arity feature,
+ * there is one split per category.
+
+ * Categorical case (a) features are called unordered features.
+ * Other cases are called ordered features.
+ *
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
- * parameters for construction the DecisionTree
- * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree
- * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache
- * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1)
+ * parameters for construction the DecisionTree
+ * @return A tuple of (splits,bins).
+ * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
+ * of size (numFeatures, numBins - 1).
+ * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
+ * of size (numFeatures, numBins).
*/
protected[tree] def findSplitsBins(
input: RDD[LabeledPoint],
strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {
+
val count = input.count()
// Find the number of features by looking at the first sample
@@ -1271,7 +1377,8 @@ object DecisionTree extends Serializable with Logging {
logDebug("fraction of data used for calculating quantiles = " + fraction)
// sampled input for RDD calculation
- val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect()
+ val sampledInput =
+ input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
val numSamples = sampledInput.length
val stride: Double = numSamples.toDouble / numBins
@@ -1294,8 +1401,10 @@ object DecisionTree extends Serializable with Logging {
val stride: Double = numSamples.toDouble / numBins
logDebug("stride = " + stride)
for (index <- 0 until numBins - 1) {
- val sampleIndex = (index + 1) * stride.toInt
- val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
+ val sampleIndex = index * stride.toInt
+ // Set threshold halfway in between 2 samples.
+ val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+ val split = new Split(featureIndex, threshold, Continuous, List())
splits(featureIndex)(index) = split
}
} else { // Categorical feature
@@ -1304,8 +1413,10 @@ object DecisionTree extends Serializable with Logging {
= numBins > math.pow(2, featureCategories.toInt - 1) - 1
// Use different bin/split calculation strategy for categorical features in multiclass
- // classification that satisfy the space constraint
- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+ // classification that satisfy the space constraint.
+ val isUnorderedFeature =
+ isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+ if (isUnorderedFeature) {
// 2^(maxFeatureValue- 1) - 1 combinations
var index = 0
while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
@@ -1330,8 +1441,13 @@ object DecisionTree extends Serializable with Logging {
}
index += 1
}
- } else {
-
+ } else { // ordered feature
+ /* For a given categorical feature, use a subsample of the data
+ * to choose how to arrange possible splits.
+ * This examines each category and computes a centroid.
+ * These centroids are later used to sort the possible splits.
+ * centroidForCategories is a mapping: category (for the given feature) --> centroid
+ */
val centroidForCategories = {
if (isMulticlassClassification) {
// For categorical variables in multiclass classification,
@@ -1341,7 +1457,7 @@ object DecisionTree extends Serializable with Logging {
.groupBy(_._1)
.mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
.map(x => (x._1, x._2.values.toArray))
- .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
+ .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
} else { // regression or binary classification
// For categorical variables in regression and binary classification,
// each bin is a category. The bins are sorted and they
@@ -1352,7 +1468,7 @@ object DecisionTree extends Serializable with Logging {
}
}
- logDebug("centriod for categories = " + centroidForCategories.mkString(","))
+ logDebug("centroid for categories = " + centroidForCategories.mkString(","))
// Check for missing categorical variables and putting them last in the sorted list.
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
@@ -1367,7 +1483,7 @@ object DecisionTree extends Serializable with Logging {
// bins sorted by centroids
val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
- logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
+ logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
var categoriesForSplit = List[Double]()
categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
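To see the off-by-one fix concretely, a self-contained sketch of the corrected threshold choice on a tiny sorted sample (stride fixed to 1 for clarity; object and value names are illustrative):

// Sketch of the corrected threshold choice for a continuous feature:
// thresholds fall halfway between consecutive sorted sample values,
// never exactly on a training example's value.
object ThresholdSketch {
  def main(args: Array[String]): Unit = {
    val featureSamples = Array(1.0, 2.0, 4.0, 8.0).sorted
    val numBins = featureSamples.length // so stride = numSamples / numBins = 1
    val thresholds =
      (0 until numBins - 1).map { index =>
        (featureSamples(index) + featureSamples(index + 1)) / 2.0
      }
    println(thresholds) // Vector(1.5, 3.0, 6.0)
  }
}

Under the old code the thresholds would have been sample values themselves (shifted by one index), which misclassified boundary examples on small datasets.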
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 7c027ac2fd..5c65b537b6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -27,7 +27,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* Stores all the configuration options for tree construction
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
- * @param maxDepth maximum depth of the tree
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
* @param numClassesForClassification number of classes for classification. Default value is 2
* leads to binary classification
* @param maxBins maximum number of bins used for splitting features
@@ -52,7 +53,9 @@ class Strategy (
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
val maxMemoryInMB: Int = 128) extends Serializable {
- require(numClassesForClassification >= 2)
+ if (algo == Classification) {
+ require(numClassesForClassification >= 2)
+ }
val isMulticlassClassification = numClassesForClassification > 2
val isMulticlassWithCategoricalFeatures
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
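Under the new depth convention documented for Strategy (depth 0 = a single leaf), a tree of depth d has at most 2^(d+1) - 1 nodes, matching the maxNumNodes change in DecisionTree.scala. A quick self-contained check:

// With depth 0 meaning a single leaf, a tree of depth d has at most 2^(d+1) - 1 nodes.
object DepthSketch {
  def maxNumNodes(maxDepth: Int): Int = math.pow(2, maxDepth + 1).toInt - 1

  def main(args: Array[String]): Unit = {
    (0 to 3).foreach { d =>
      println(s"maxDepth = $d => at most ${maxNumNodes(d)} nodes")
    }
    // maxDepth = 0 => 1 node; 1 => 3; 2 => 7; 3 => 15
  }
}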
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index a0e2d91762..9297c20596 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -34,10 +34,13 @@ object Entropy extends Impurity {
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
- * @return information value
+ * @return information value, or 0 if totalCount = 0
*/
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double = {
+ if (totalCount == 0) {
+ return 0
+ }
val numClasses = counts.length
var impurity = 0.0
var classIndex = 0
@@ -58,6 +61,7 @@ object Entropy extends Impurity {
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
+ * @return information value, or 0 if count = 0
*/
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
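The totalCount = 0 guard added above can be sketched in isolation; the log2 helper and the skipping of zero counts follow the surrounding implementation, but this is an illustration, not the patched class:

// Sketch of the guarded multiclass entropy: an empty count vector yields 0
// rather than NaN from dividing by a zero total.
object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  def entropy(counts: Array[Double], totalCount: Double): Double = {
    if (totalCount == 0) return 0
    counts.filter(_ != 0).map { classCount =>
      val freq = classCount / totalCount
      -freq * log2(freq)
    }.sum
  }

  def main(args: Array[String]): Unit = {
    println(entropy(Array(0.0, 0.0), 0))  // 0.0 (guarded)
    println(entropy(Array(5.0, 5.0), 10)) // 1.0 for a 50/50 split
  }
}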
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 48144b5e6d..2874bcf496 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -33,10 +33,13 @@ object Gini extends Impurity {
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
- * @return information value
+ * @return information value, or 0 if totalCount = 0
*/
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double = {
+ if (totalCount == 0) {
+ return 0
+ }
val numClasses = counts.length
var impurity = 1.0
var classIndex = 0
@@ -54,6 +57,7 @@ object Gini extends Impurity {
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
+ * @return information value, or 0 if count = 0
*/
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 7b2a9320cc..92b0c7b4a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -31,7 +31,7 @@ trait Impurity extends Serializable {
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
- * @return information value
+ * @return information value, or 0 if totalCount = 0
*/
@DeveloperApi
def calculate(counts: Array[Double], totalCount: Double): Double
@@ -42,7 +42,7 @@ trait Impurity extends Serializable {
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
- * @return information value
+ * @return information value, or 0 if count = 0
*/
@DeveloperApi
def calculate(count: Double, sum: Double, sumSquares: Double): Double
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 97149a99ea..698a1a2a8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -31,7 +31,7 @@ object Variance extends Impurity {
* information calculation for multiclass classification
* @param counts Array[Double] with counts for each label
* @param totalCount sum of counts for all labels
- * @return information value
+ * @return information value, or 0 if totalCount = 0
*/
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double =
@@ -43,9 +43,13 @@ object Variance extends Impurity {
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
+ * @return information value, or 0 if count = 0
*/
@DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
+ if (count == 0) {
+ return 0
+ }
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
}
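A minimal sketch of the guarded variance computation above, showing that an empty node now yields 0 rather than NaN:

// Sketch of the guarded variance impurity: returns 0 for an empty node
// instead of NaN from the 0/0 division.
object VarianceSketch {
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    if (count == 0) return 0
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }

  def main(args: Array[String]): Unit = {
    println(variance(0, 0, 0))     // 0.0 (guarded), not NaN
    println(variance(2, 3.0, 5.0)) // (5 - 4.5) / 2 = 0.25
  }
}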
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index bf692ca8c4..3d3406b5d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -24,7 +24,8 @@ import org.apache.spark.mllib.linalg.Vector
/**
* :: Experimental ::
- * Model to store the decision tree parameters
+ * Decision tree model for classification or regression.
+ * This model stores the decision tree structure and parameters.
* @param topNode root node
* @param algo algorithm type -- classification or regression
*/
@@ -50,4 +51,32 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
def predict(features: RDD[Vector]): RDD[Double] = {
features.map(x => predict(x))
}
+
+ /**
+ * Get number of nodes in tree, including leaf nodes.
+ */
+ def numNodes: Int = {
+ 1 + topNode.numDescendants
+ }
+
+ /**
+ * Get depth of tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ def depth: Int = {
+ topNode.subtreeDepth
+ }
+
+ /**
+ * Print full model.
+ */
+ override def toString: String = algo match {
+ case Classification =>
+ s"DecisionTreeModel classifier\n" + topNode.subtreeToString(2)
+ case Regression =>
+ s"DecisionTreeModel regressor\n" + topNode.subtreeToString(2)
+ case _ => throw new IllegalArgumentException(
+ s"DecisionTreeModel given unknown algo parameter: $algo.")
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 682f213f41..944f11c2c2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -91,4 +91,60 @@ class Node (
}
}
}
+
+ /**
+ * Get the number of nodes in tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int = {
+ if (isLeaf) {
+ 0
+ } else {
+ 2 + leftNode.get.numDescendants + rightNode.get.numDescendants
+ }
+ }
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node.
+ */
+ private[tree] def subtreeDepth: Int = {
+ if (isLeaf) {
+ 0
+ } else {
+ 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth)
+ }
+ }
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+
+ def splitToString(split: Split, left: Boolean): String = {
+ split.featureType match {
+ case Continuous => if (left) {
+ s"(feature ${split.feature} <= ${split.threshold})"
+ } else {
+ s"(feature ${split.feature} > ${split.threshold})"
+ }
+ case Categorical => if (left) {
+ s"(feature ${split.feature} in ${split.categories.mkString("{",",","}")})"
+ } else {
+ s"(feature ${split.feature} not in ${split.categories.mkString("{",",","}")})"
+ }
+ }
+ }
+ val prefix: String = " " * indentFactor
+ if (isLeaf) {
+ prefix + s"Predict: $predict\n"
+ } else {
+ prefix + s"If ${splitToString(split.get, left=true)}\n" +
+ leftNode.get.subtreeToString(indentFactor + 1) +
+ prefix + s"Else ${splitToString(split.get, left=false)}\n" +
+ rightNode.get.subtreeToString(indentFactor + 1)
+ }
+ }
+
}
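The recursion behind numDescendants and subtreeDepth (and hence the new DecisionTreeModel.numNodes and depth) can be sketched with a stand-in node class; SketchNode is illustrative, not the mllib Node:

// Self-contained sketch of the recursive counting used by numDescendants and
// subtreeDepth; a leaf contributes 0 to both counts.
case class SketchNode(left: Option[SketchNode], right: Option[SketchNode]) {
  def isLeaf: Boolean = left.isEmpty && right.isEmpty
  def numDescendants: Int =
    if (isLeaf) 0 else 2 + left.get.numDescendants + right.get.numDescendants
  def subtreeDepth: Int =
    if (isLeaf) 0 else 1 + math.max(left.get.subtreeDepth, right.get.subtreeDepth)
}

object NodeSketch {
  def main(args: Array[String]): Unit = {
    val leaf = SketchNode(None, None)
    val root = SketchNode(Some(SketchNode(Some(leaf), Some(leaf))), Some(leaf))
    println(root.numDescendants) // 4
    println(root.subtreeDepth)   // 2
    // DecisionTreeModel.numNodes = 1 + numDescendants = 5 for this tree.
  }
}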
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5961a618c5..10462db700 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -20,8 +20,7 @@ package org.apache.spark.mllib.tree
import org.scalatest.FunSuite
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.Filter
-import org.apache.spark.mllib.tree.model.Split
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
@@ -31,6 +30,18 @@ import org.apache.spark.mllib.regression.LabeledPoint
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
+ def validateClassifier(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy)
+ }
+
test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -50,7 +61,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Classification,
Gini,
- maxDepth = 3,
+ maxDepth = 2,
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
@@ -130,7 +141,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Classification,
Gini,
- maxDepth = 3,
+ maxDepth = 2,
numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
@@ -236,7 +247,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("extract categories from a number for multiclass classification") {
val l = DecisionTree.extractMultiClassCategories(13, 10)
assert(l.length === 3)
- assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
+ assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}
test("split and bin calculations for unordered categorical variables with multiclass " +
@@ -247,7 +258,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Classification,
Gini,
- maxDepth = 3,
+ maxDepth = 2,
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -341,7 +352,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Classification,
Gini,
- maxDepth = 3,
+ maxDepth = 2,
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
@@ -397,7 +408,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Classification,
Gini,
numClassesForClassification = 2,
- maxDepth = 3,
+ maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -413,7 +424,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict === 1)
- assert(stats.prob == 0.6)
+ assert(stats.prob === 0.6)
assert(stats.impurity > 0.2)
}
@@ -424,7 +435,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Regression,
Variance,
- maxDepth = 3,
+ maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
@@ -439,7 +450,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict == 0.6)
+ assert(stats.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -460,7 +471,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
@@ -483,7 +493,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
@@ -507,7 +516,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
@@ -531,7 +539,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
- assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
@@ -587,7 +594,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val input = sc.parallelize(arr)
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
@@ -602,12 +609,78 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplit.featureType === Categorical)
}
+ test("stump with 1 continuous variable for binary classification, to check off-by-1 error") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
+ arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
+ arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
+ arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+ numClassesForClassification = 2)
+
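+ // The only consistent split lies strictly between feature values 0.0 (label 0.0)
+ // and 1.0 (label 1.0), so a correct stump classifies all 4 points: 3 nodes, depth 1.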
+ val model = DecisionTree.train(input, strategy)
+ validateClassifier(model, arr, 1.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+ }
+
+ test("stump with 2 continuous variables for binary classification") {
+ val arr = new Array[LabeledPoint](4)
+ arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+ arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+ arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
+
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+ numClassesForClassification = 2)
+
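+ // Feature 0 is 0.0 for every point; only feature 1 varies with the label,
+ // so the root split should be on feature 1.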
+ val model = DecisionTree.train(input, strategy)
+ validateClassifier(model, arr, 1.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+ assert(model.topNode.split.get.feature === 1)
+ }
+
+ test("stump with categorical variables for multiclass classification, with just enough bins") {
+ val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
+ numClassesForClassification = 3, maxBins = maxBins,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ assert(strategy.isMulticlassClassification)
+
+ val model = DecisionTree.train(input, strategy)
+ validateClassifier(model, arr, 1.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+
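+ // Cross-check the root split found by findBestSplits against the trained model.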
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ Array[List[Filter]](), splits, bins, 10)
+
+ assert(bestSplits.length === 1)
+ val bestSplit = bestSplits(0)._1
+ assert(bestSplit.feature === 0)
+ assert(bestSplit.categories.length === 1)
+ assert(bestSplit.categories.contains(1))
+ assert(bestSplit.featureType === Categorical)
+ val gain = bestSplits(0)._2
+ assert(gain.leftImpurity === 0)
+ assert(gain.rightImpurity === 0)
+ }
+
test("stump with continuous variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val input = sc.parallelize(arr)
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3)
assert(strategy.isMulticlassClassification)
+
+ val model = DecisionTree.train(input, strategy)
+ validateClassifier(model, arr, 0.9)
+
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
@@ -625,9 +698,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with continuous + categorical variables for multiclass classification") {
val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
val input = sc.parallelize(arr)
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
+
+ val model = DecisionTree.train(input, strategy)
+ validateClassifier(model, arr, 0.9)
+
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
@@ -644,7 +721,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
test("stump with categorical variables for ordered multiclass classification") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
val input = sc.parallelize(arr)
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)