 docs/mllib-decision-tree.md                                                        |   8
 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala   |  21
 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala                | 732
 mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala      |  12
 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala            |  36
 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala               |  33
 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala           |   8
 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala           |  11
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala                   |   2
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala  |   8
 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala           | 303
 project/MimaExcludes.scala                                                         |  10
 12 files changed, 926 insertions(+), 258 deletions(-)
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 9cd768599e..9cbd880897 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.
**Categorical features**
-For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
-binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
+For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
+binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
categorical feature values by the proportion of labels falling in one of the two classes (see
Section 9.2.4 in
[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
details). For example, for a binary classification problem with one categorical feature with three
categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
features are ordered as A followed by C followed by B, i.e. A, C, B. The two split candidates are A \| C, B
-and A , B \| C where \| denotes the split.
+and A, C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
+when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
+is used for ordering.
### Stopping rule
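For reference, the ordering heuristic is easy to see in code. A minimal sketch (editor's
illustration, not part of this patch; all names are hypothetical):

    // Order categories by the proportion of label 1, then split between adjacent
    // categories: with proportions A = 0.2, B = 0.6, C = 0.4 the order is A, C, B and
    // the M - 1 = 2 candidate splits are A | C, B and A, C | B.
    val proportionOfLabelOne = Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)
    val ordered = proportionOfLabelOne.toSeq.sortBy(_._2).map(_._1)
    (1 until ordered.length).map(i => ordered.splitAt(i)).foreach { case (left, right) =>
      println(left.mkString(", ") + " | " + right.mkString(", "))
    }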
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b3cc361154..43f13fe24f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
case class Params(
input: String = null,
algo: Algo = Classification,
+ numClassesForClassification: Int = 2,
maxDepth: Int = 5,
impurity: ImpurityType = Gini,
maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
.action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("numClassesForClassification")
+ .text(s"number of classes for classification, "
+ + s"default: ${defaultParams.numClassesForClassification}")
+ .action((x, c) => c.copy(numClassesForClassification = x))
opt[Int]("maxBins")
.text(s"max number of bins, default: ${defaultParams.maxBins}")
.action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,13 @@ object DecisionTreeRunner {
case Variance => impurity.Variance
}
- val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val strategy = new Strategy(
+ algo = params.algo,
+ impurity = impurityCalculator,
+ maxDepth = params.maxDepth,
+ maxBins = params.maxBins,
+ numClassesForClassification = params.numClassesForClassification)
val model = DecisionTree.train(training, strategy)
if (params.algo == Classification) {
@@ -139,12 +150,8 @@ object DecisionTreeRunner {
*/
private def accuracyScore(
model: DecisionTreeModel,
- data: RDD[LabeledPoint],
- threshold: Double = 0.5): Double = {
- def predictedValue(features: Vector): Double = {
- if (model.predict(features) < threshold) 0.0 else 1.0
- }
- val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+ data: RDD[LabeledPoint]): Double = {
+ val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
val count = data.count()
correctCount.toDouble / count
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 74d5d7ba10..ad32e3f456 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -77,11 +77,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
// Max memory usage for aggregates
val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- val numElementsPerNode =
- strategy.algo match {
- case Classification => 2 * numBins * numFeatures
- case Regression => 3 * numBins * numFeatures
- }
+ val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
+ strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
+ strategy.algo)
logDebug("numElementsPerNode = " + numElementsPerNode)
val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -109,8 +107,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
logDebug("#####################################")
// Find best split for all nodes at a level.
- val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
- level, filters, splits, bins, maxLevelForSingleGroup)
+ val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+ strategy, level, filters, splits, bins, maxLevelForSingleGroup)
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
// Extract info for nodes at the current level.
@@ -212,7 +210,7 @@ object DecisionTree extends Serializable with Logging {
* @return a DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
- new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ new DecisionTree(strategy).train(input)
}
/**
@@ -233,10 +231,33 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int): DecisionTreeModel = {
- val strategy = new Strategy(algo,impurity,maxDepth)
- new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ val strategy = new Strategy(algo, impurity, maxDepth)
+ new DecisionTree(strategy).train(input)
}
+ /**
+ * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary and multiclass classification and
+   * regression. For classification, the label for each instance should take a value in
+   * {0, 1, ..., numClassesForClassification - 1}.
+ *
+ * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+ * training data
+ * @param algo algorithm, classification or regression
+ * @param impurity impurity criterion used for information gain calculation
+   * @param maxDepth maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification.
+ * @return a DecisionTreeModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClassesForClassification: Int): DecisionTreeModel = {
+ val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
+ new DecisionTree(strategy).train(input)
+ }
/**
* Method to train a decision tree model where the instances are represented as an RDD of
@@ -250,6 +271,7 @@ object DecisionTree extends Serializable with Logging {
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -264,12 +286,13 @@ object DecisionTree extends Serializable with Logging {
algo: Algo,
impurity: Impurity,
maxDepth: Int,
+ numClassesForClassification: Int,
maxBins: Int,
quantileCalculationStrategy: QuantileStrategy,
categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
- val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
- categoricalFeaturesInfo)
- new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+ val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+ quantileCalculationStrategy, categoricalFeaturesInfo)
+ new DecisionTree(strategy).train(input)
}
private val InvalidBinIndex = -1
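As a usage sketch for the new five-argument overload (editor's illustration; assumes an
existing RDD of LabeledPoint named trainingData):

    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.impurity.Gini

    // Train a depth-4 tree for a 3-class problem with Gini impurity.
    val model = DecisionTree.train(trainingData, Classification, Gini, 4, 3)
    val prediction = model.predict(trainingData.first().features)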
@@ -381,6 +404,14 @@ object DecisionTree extends Serializable with Logging {
logDebug("numFeatures = " + numFeatures)
val numBins = bins(0).length
logDebug("numBins = " + numBins)
+ val numClasses = strategy.numClassesForClassification
+ logDebug("numClasses = " + numClasses)
+ val isMulticlassClassification = strategy.isMulticlassClassification
+ logDebug("isMulticlassClassification = " + isMulticlassClassification)
+    val isMulticlassClassificationWithCategoricalFeatures =
+      strategy.isMulticlassWithCategoricalFeatures
+    logDebug("isMulticlassClassificationWithCategoricalFeatures = " +
+      isMulticlassClassificationWithCategoricalFeatures)
// shift when more than one group is used at deep tree level
val groupShift = numNodes * groupIndex
@@ -436,10 +467,8 @@ object DecisionTree extends Serializable with Logging {
/**
* Find bin for one feature.
*/
- def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- isFeatureContinuous: Boolean): Int = {
+ def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
+ isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
val binForFeatures = bins(featureIndex)
val feature = labeledPoint.features(featureIndex)
@@ -468,16 +497,27 @@ object DecisionTree extends Serializable with Logging {
}
/**
+ * Sequential search helper method to find bin for categorical feature in multiclass
+ * classification. The category is returned since each category can belong to multiple
+ * splits. The actual left/right child allocation per split is performed in the
+ * sequential phase of the bin aggregate operation.
+ */
+ def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+ labeledPoint.features(featureIndex).toInt
+ }
+
+ /**
* Sequential search helper method to find bin for categorical feature.
*/
- def sequentialBinSearchForCategoricalFeature(): Int = {
- val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+ def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
var binIndex = 0
while (binIndex < numCategoricalBins) {
val bin = bins(featureIndex)(binIndex)
- val category = bin.category
+ val categories = bin.highSplit.categories
val features = labeledPoint.features
- if (category == features(featureIndex)) {
+ if (categories.contains(features(featureIndex))) {
return binIndex
}
binIndex += 1
@@ -494,7 +534,13 @@ object DecisionTree extends Serializable with Logging {
binIndex
} else {
// Perform sequential search to find bin for categorical features.
- val binIndex = sequentialBinSearchForCategoricalFeature()
+ val binIndex = {
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+ sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+ } else {
+ sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+ }
+ }
if (binIndex == -1){
throw new UnknownError("no bin was found for categorical variable.")
}
@@ -506,13 +552,16 @@ object DecisionTree extends Serializable with Logging {
* Finds bins for all nodes (and all features) at a given level.
* For l nodes, k features the storage is as follows:
* label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
- * where b_ij is an integer between 0 and numBins - 1.
+     * where b_ij is an integer between 0 and numBins - 1 for regression and binary
+     * classification, and is the categorical feature value for multiclass classification.
     * An invalid sample is denoted by setting the bin for feature 1 to -1.
*/
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
// Calculate bin index and label per feature per node.
val arr = new Array[Double](1 + (numFeatures * numNodes))
+ // First element of the array is the label of the instance.
arr(0) = labeledPoint.label
+ // Iterate over nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
val parentFilters = findParentFilters(nodeIndex)
@@ -525,8 +574,19 @@ object DecisionTree extends Serializable with Logging {
} else {
var featureIndex = 0
while (featureIndex < numFeatures) {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
- arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous)
+ val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
+ val isFeatureContinuous = featureInfo.isEmpty
+ if (isFeatureContinuous) {
+ arr(shift + featureIndex)
+ = findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
+ } else {
+ val featureCategories = featureInfo.get
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ arr(shift + featureIndex)
+ = findBin(featureIndex, labeledPoint, isFeatureContinuous,
+ isSpaceSufficientForAllCategoricalSplits)
+ }
featureIndex += 1
}
}
@@ -535,18 +595,61 @@ object DecisionTree extends Serializable with Logging {
arr
}
+ // Find feature bins for all nodes at a level.
+ val binMappedRDD = input.map(x => findBinsForLevel(x))
+
+ def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
+ label: Double, featureIndex: Int) = {
+
+ // Find the bin index for this feature.
+ val arrShift = 1 + numFeatures * nodeIndex
+ val arrIndex = arrShift + featureIndex
+      // Update the count for this bin and label.
+ val aggShift = numClasses * numBins * numFeatures * nodeIndex
+ val aggIndex
+ = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+ val labelInt = label.toInt
+ agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+ }
+
+    def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
+        label: Double, agg: Array[Double], rightChildShift: Int) = {
+      // Find the feature value (category) stored for this feature.
+      val arrShift = 1 + numFeatures * nodeIndex
+      val arrIndex = arrShift + featureIndex
+      val featureValue = arr(arrIndex).toInt
+      // Offset of this (node, feature) in the bin aggregate array.
+      val aggShift = numClasses * numBins * numFeatures * nodeIndex
+      val aggFeatureOffset = aggShift + numClasses * featureIndex * numBins
+      // Each category can belong to multiple splits, so update the left child count of
+      // every bin whose split contains the feature value, and the right child count
+      // otherwise.
+      val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+      val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+      val labelInt = label.toInt
+      var binIndex = 0
+      while (binIndex < numCategoricalBins) {
+        val aggIndex = aggFeatureOffset + binIndex * numClasses + labelInt
+        if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
+          agg(aggIndex) = agg(aggIndex) + 1
+        } else {
+          agg(rightChildShift + aggIndex) = agg(rightChildShift + aggIndex) + 1
+        }
+        binIndex += 1
+      }
+    }
+
/**
* Performs a sequential aggregation over a partition for classification. For l nodes
* and k features, the count for the instance's label in the matching bin is incremented.
*
* @param agg Array[Double] storing aggregate calculation of size
- * 2 * numSplits * numFeatures*numNodes for classification
+     * numClasses * numBins * numFeatures * numNodes for classification
* @param arr Array[Double] of size 1 + (numFeatures * numNodes)
* @return Array[Double] storing aggregate calculation of size
* numClasses * numBins * numFeatures * numNodes for classification
*/
- def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -559,15 +662,52 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- // Find the bin index for this feature.
- val arrShift = 1 + numFeatures * nodeIndex
- val arrIndex = arrShift + featureIndex
- // Update the left or right count for one bin.
- val aggShift = 2 * numBins * numFeatures * nodeIndex
- val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
- label match {
- case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
- case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+ featureIndex += 1
+ }
+ }
+ nodeIndex += 1
+ }
+ }
+
+ /**
+     * Performs a sequential aggregation over a partition for classification. For l nodes
+     * and k features, the left or right child count for the instance's label is incremented
+     * for every bin, depending on whether the bin's split contains the feature value.
+     *
+     * @param agg Array[Double] storing aggregate calculation of size
+     * 2 * numClasses * numBins * numFeatures * numNodes for classification
+     * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+     * @return Array[Double] storing aggregate calculation of size
+     * 2 * numClasses * numBins * numFeatures * numNodes for classification
+ */
+ def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+ // Iterate over all nodes.
+ var nodeIndex = 0
+ while (nodeIndex < numNodes) {
+ // Check whether the instance was valid for this nodeIndex.
+ val validSignalIndex = 1 + numFeatures * nodeIndex
+ val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+ if (isSampleValidForNode) {
+ val rightChildShift = numClasses * numBins * numFeatures * numNodes
+ // actual class label
+ val label = arr(0)
+ // Iterate over all features.
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+ } else {
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ if (isSpaceSufficientForAllCategoricalSplits) {
+ updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
+ rightChildShift)
+ } else {
+ updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+ }
}
featureIndex += 1
}
@@ -586,7 +726,7 @@ object DecisionTree extends Serializable with Logging {
* @return Array[Double] storing aggregate calculation of size
* 3 * numSplits * numFeatures * numNodes for regression
*/
- def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+ def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
// Iterate over all nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
@@ -620,17 +760,20 @@ object DecisionTree extends Serializable with Logging {
*/
def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
strategy.algo match {
- case Classification => classificationBinSeqOp(arr, agg)
+ case Classification =>
+          if (isMulticlassClassificationWithCategoricalFeatures) {
+ unorderedClassificationBinSeqOp(arr, agg)
+ } else {
+ orderedClassificationBinSeqOp(arr, agg)
+ }
case Regression => regressionBinSeqOp(arr, agg)
}
agg
}
// Calculate bin aggregate length for classification or regression.
- val binAggregateLength = strategy.algo match {
- case Classification => 2 * numBins * numFeatures * numNodes
- case Regression => 3 * numBins * numFeatures * numNodes
- }
+ val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
+ isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
logDebug("binAggregateLength = " + binAggregateLength)
/**
@@ -649,9 +792,6 @@ object DecisionTree extends Serializable with Logging {
combinedAggregate
}
- // Find feature bins for all nodes at a level.
- val binMappedRDD = input.map(x => findBinsForLevel(x))
-
// Calculate bin aggregates.
val binAggregates = {
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
@@ -668,42 +808,55 @@ object DecisionTree extends Serializable with Logging {
* @return information gain and statistics for all splits
*/
def calculateGainForSplit(
- leftNodeAgg: Array[Array[Double]],
+ leftNodeAgg: Array[Array[Array[Double]]],
featureIndex: Int,
splitIndex: Int,
- rightNodeAgg: Array[Array[Double]],
+ rightNodeAgg: Array[Array[Array[Double]]],
topImpurity: Double): InformationGainStats = {
strategy.algo match {
case Classification =>
- val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
- val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
- val leftCount = left0Count + left1Count
-
- val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
- val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
- val rightCount = right0Count + right1Count
+ var classIndex = 0
+ val leftCounts: Array[Double] = new Array[Double](numClasses)
+ val rightCounts: Array[Double] = new Array[Double](numClasses)
+ var leftTotalCount = 0.0
+ var rightTotalCount = 0.0
+ while (classIndex < numClasses) {
+ val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
+ val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
+ leftCounts(classIndex) = leftClassCount
+ leftTotalCount += leftClassCount
+ rightCounts(classIndex) = rightClassCount
+ rightTotalCount += rightClassCount
+ classIndex += 1
+ }
val impurity = {
if (level > 0) {
topImpurity
} else {
// Calculate impurity for root node.
- strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+ val rootNodeCounts = new Array[Double](numClasses)
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
+ classIndex += 1
+ }
+ strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
}
}
- if (leftCount == 0) {
- return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
+ if (leftTotalCount == 0) {
+ return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
}
- if (rightCount == 0) {
- return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
+ if (rightTotalCount == 0) {
+ return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
}
- val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
- val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+ val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
+ val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
- val leftWeight = leftCount.toDouble / (leftCount + rightCount)
- val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+ val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
+ val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
val gain = {
if (level > 0) {
@@ -713,17 +866,34 @@ object DecisionTree extends Serializable with Logging {
}
}
- val predict = (left1Count + right1Count) / (leftCount + rightCount)
+ val totalCount = leftTotalCount + rightTotalCount
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+ // Sum of count for each label
+ val leftRightCounts: Array[Double]
+ = leftCounts.zip(rightCounts)
+ .map{case (leftCount, rightCount) => leftCount + rightCount}
+
+ def indexOfLargestArrayElement(array: Array[Double]): Int = {
+            val result = array.foldLeft((-1, Double.MinValue, 0)) {
+              case ((maxIndex, maxValue, currentIndex), currentValue) =>
+                if (currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
+                else (maxIndex, maxValue, currentIndex + 1)
+ }
+ if (result._1 < 0) 0 else result._1
+ }
+
+ val predict = indexOfLargestArrayElement(leftRightCounts)
+ val prob = leftRightCounts(predict) / totalCount
+
+ new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
case Regression =>
- val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
- val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
- val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+ val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
+ val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
+ val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
- val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
- val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
- val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+ val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
+ val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
+ val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
val impurity = {
if (level > 0) {
@@ -768,104 +938,149 @@ object DecisionTree extends Serializable with Logging {
/**
* Extracts left and right split aggregates.
* @param binData Array[Double] storing the bin aggregates for one node
- * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
- * Array[Double]) where each array is of size(numFeature,2*(numSplits-1))
+ * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
+ * Array[Array[Array[Double\]\]\]) where each array is of size(numFeature,
+ * (numBins - 1), numClasses)
*/
def extractLeftRightNodeAggregates(
- binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+ binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
+
+ def findAggForOrderedFeatureClassification(
+ leftNodeAgg: Array[Array[Array[Double]]],
+ rightNodeAgg: Array[Array[Array[Double]]],
+ featureIndex: Int) {
+
+ // shift for this featureIndex
+ val shift = numClasses * featureIndex * numBins
+
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
+ = binData(shift + (numClasses * (numBins - 1)) + classIndex)
+ classIndex += 1
+ }
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ var innerClassIndex = 0
+ while (innerClassIndex < numClasses) {
+ leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex)
+ = binData(shift + numClasses * splitIndex + innerClassIndex) +
+ leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
+ binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
+ innerClassIndex += 1
+ }
+ splitIndex += 1
+ }
+ }
+
+ def findAggForUnorderedFeatureClassification(
+ leftNodeAgg: Array[Array[Array[Double]]],
+ rightNodeAgg: Array[Array[Array[Double]]],
+ featureIndex: Int) {
+
+ val rightChildShift = numClasses * numBins * numFeatures
+ var splitIndex = 0
+ while (splitIndex < numBins - 1) {
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ // shift for this featureIndex
+ val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
+ val leftBinValue = binData(shift + classIndex)
+ val rightBinValue = binData(rightChildShift + shift + classIndex)
+ leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
+ rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
+ classIndex += 1
+ }
+ splitIndex += 1
+ }
+ }
+
+ def findAggForRegression(
+ leftNodeAgg: Array[Array[Array[Double]]],
+ rightNodeAgg: Array[Array[Array[Double]]],
+ featureIndex: Int) {
+
+ // shift for this featureIndex
+ val shift = 3 * featureIndex * numBins
+ // left node aggregate for the lowest split
+ leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
+ leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
+ leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
+
+ // right node aggregate for the highest split
+ rightNodeAgg(featureIndex)(numBins - 2)(0) =
+ binData(shift + (3 * (numBins - 1)))
+ rightNodeAgg(featureIndex)(numBins - 2)(1) =
+ binData(shift + (3 * (numBins - 1)) + 1)
+ rightNodeAgg(featureIndex)(numBins - 2)(2) =
+ binData(shift + (3 * (numBins - 1)) + 2)
+
+ // Iterate over all splits.
+ var splitIndex = 1
+ while (splitIndex < numBins - 1) {
+ var i = 0 // index for regression histograms
+ while (i < 3) { // count, sum, sum^2
+ // calculating left node aggregate for a split as a sum of left node aggregate of a
+ // lower split and the left bin aggregate of a bin where the split is a high split
+ leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
+ leftNodeAgg(featureIndex)(splitIndex - 1)(i)
+ // calculating right node aggregate for a split as a sum of right node aggregate of a
+ // higher split and the right bin aggregate of a bin where the split is a low split
+ rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
+ binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
+ rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
+ i += 1
+ }
+ splitIndex += 1
+ }
+ }
+
strategy.algo match {
case Classification =>
// Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
- val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
- // Iterate over all features.
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
var featureIndex = 0
while (featureIndex < numFeatures) {
- // shift for this featureIndex
- val shift = 2 * featureIndex * numBins
-
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0) = binData(shift + 0)
- leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(2 * (numBins - 2))
- = binData(shift + (2 * (numBins - 1)))
- rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
- = binData(shift + (2 * (numBins - 1)) + 1)
-
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
- leftNodeAgg(featureIndex)(2 * splitIndex - 2)
- leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
- leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
-
- // calculating right node aggregate for a split as a sum of right node aggregate of a
- // higher split and the right bin aggregate of a bin where the split is a low split
- rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
- binData(shift + (2 *(numBins - 1 - splitIndex))) +
- rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
- rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
- rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
-
- splitIndex += 1
+ if (isMulticlassClassificationWithCategoricalFeatures){
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+ } else {
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ if (isSpaceSufficientForAllCategoricalSplits) {
+ findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+ } else {
+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+ }
+ }
+ } else {
+ findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
}
featureIndex += 1
}
+
(leftNodeAgg, rightNodeAgg)
case Regression =>
// Initialize left and right split aggregates.
- val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
- val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+ val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+ val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- // shift for this featureIndex
- val shift = 3 * featureIndex * numBins
- // left node aggregate for the lowest split
- leftNodeAgg(featureIndex)(0) = binData(shift + 0)
- leftNodeAgg(featureIndex)(1) = binData(shift + 1)
- leftNodeAgg(featureIndex)(2) = binData(shift + 2)
-
- // right node aggregate for the highest split
- rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
- binData(shift + (3 * (numBins - 1)))
- rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
- binData(shift + (3 * (numBins - 1)) + 1)
- rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
- binData(shift + (3 * (numBins - 1)) + 2)
-
- // Iterate over all splits.
- var splitIndex = 1
- while (splitIndex < numBins - 1) {
- // calculating left node aggregate for a split as a sum of left node aggregate of a
- // lower split and the left bin aggregate of a bin where the split is a high split
- leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
- leftNodeAgg(featureIndex)(3 * splitIndex - 3)
- leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
- leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
- leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
- leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
-
- // calculating right node aggregate for a split as a sum of right node aggregate of a
- // higher split and the right bin aggregate of a bin where the split is a low split
- rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
- binData(shift + (3 * (numBins - 1 - splitIndex))) +
- rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
- rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
- binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
- rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
- rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
- binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
- rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
-
- splitIndex += 1
- }
+ findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
featureIndex += 1
}
(leftNodeAgg, rightNodeAgg)
@@ -876,8 +1091,8 @@ object DecisionTree extends Serializable with Logging {
* Calculates information gain for all nodes splits.
*/
def calculateGainsForAllNodeSplits(
- leftNodeAgg: Array[Array[Double]],
- rightNodeAgg: Array[Array[Double]],
+ leftNodeAgg: Array[Array[Array[Double]]],
+ rightNodeAgg: Array[Array[Array[Double]]],
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
@@ -918,7 +1133,22 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
- while (splitIndex < numBins - 1) {
+          val maxSplitIndex: Int = {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ numBins - 1
+ } else { // Categorical feature
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+ math.pow(2.0, featureCategories - 1).toInt - 1
+              } else { // Ordered categorical feature
+ featureCategories
+ }
+ }
+ }
+ while (splitIndex < maxSplitIndex) {
val gainStats = gains(featureIndex)(splitIndex)
if (gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
@@ -944,9 +1174,23 @@ object DecisionTree extends Serializable with Logging {
def getBinDataForNode(node: Int): Array[Double] = {
strategy.algo match {
case Classification =>
- val shift = 2 * node * numBins * numFeatures
- val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
- binsForNode
+ if (isMulticlassClassificationWithCategoricalFeatures) {
+ val shift = numClasses * node * numBins * numFeatures
+ val rightChildShift = numClasses * numBins * numFeatures * numNodes
+ val binsForNode = {
+ val leftChildData
+ = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+ val rightChildData
+ = binAggregates.slice(rightChildShift + shift,
+ rightChildShift + shift + numClasses * numBins * numFeatures)
+ leftChildData ++ rightChildData
+ }
+ binsForNode
+ } else {
+ val shift = numClasses * node * numBins * numFeatures
+ val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+ binsForNode
+ }
case Regression =>
val shift = 3 * node * numBins * numFeatures
val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
@@ -963,14 +1207,26 @@ object DecisionTree extends Serializable with Logging {
val binsForNode: Array[Double] = getBinDataForNode(node)
logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
- logDebug("node impurity = " + parentNodeImpurity)
+ logDebug("parent node impurity = " + parentNodeImpurity)
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
node += 1
}
-
bestSplits
}
+ private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
+ isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
+ algo match {
+ case Classification =>
+ if (isMulticlassClassificationWithCategoricalFeatures) {
+ 2 * numClasses * numBins * numFeatures
+ } else {
+ numClasses * numBins * numFeatures
+ }
+ case Regression => 3 * numBins * numFeatures
+ }
+ }
+
/**
* Returns split and bins for decision tree calculation.
* @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -992,17 +1248,23 @@ object DecisionTree extends Serializable with Logging {
val maxBins = strategy.maxBins
val numBins = if (maxBins <= count) maxBins else count.toInt
logDebug("numBins = " + numBins)
+ val isMulticlassClassification = strategy.isMulticlassClassification
+ logDebug("isMulticlassClassification = " + isMulticlassClassification)
+
/*
- * TODO: Add a require statement ensuring #bins is always greater than the categories.
+     * Ensure #bins is always greater than the number of categories. For multiclass classification,
+ * #bins should be greater than 2^(maxCategories - 1) - 1.
* It's a limitation of the current implementation but a reasonable trade-off since features
* with large number of categories get favored over continuous features.
*/
if (strategy.categoricalFeaturesInfo.size > 0) {
val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
- require(numBins >= maxCategoriesForFeatures)
+ require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
+ "in categorical features")
}
+
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = numBins*numBins
val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1036,48 +1298,93 @@ object DecisionTree extends Serializable with Logging {
val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
splits(featureIndex)(index) = split
}
- } else {
- val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
- require(maxFeatureValue < numBins, "number of categories should be less than number " +
- "of bins")
-
- // For categorical variables, each bin is a category. The bins are sorted and they
- // are ordered by calculating the centroid of their corresponding labels.
- val centroidForCategories =
- sampledInput.map(lp => (lp.features(featureIndex),lp.label))
- .groupBy(_._1)
- .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
-
- // Check for missing categorical variables and putting them last in the sorted list.
- val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
- for (i <- 0 until maxFeatureValue) {
- if (centroidForCategories.contains(i)) {
- fullCentroidForCategories(i) = centroidForCategories(i)
- } else {
- fullCentroidForCategories(i) = Double.MaxValue
- }
- }
-
- // bins sorted by centroids
- val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
- logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
-
- var categoriesForSplit = List[Double]()
- categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
- case ((key, value), index) =>
- categoriesForSplit = key :: categoriesForSplit
- splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
- categoriesForSplit)
+ } else { // Categorical feature
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+
+ // Use different bin/split calculation strategy for categorical features in multiclass
+ // classification that satisfy the space constraint
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+          // 2^(featureCategories - 1) - 1 combinations
+ var index = 0
+ while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+ val categories: List[Double]
+ = extractMultiClassCategories(index + 1, featureCategories)
+ splits(featureIndex)(index)
+ = new Split(featureIndex, Double.MinValue, Categorical, categories)
bins(featureIndex)(index) = {
if (index == 0) {
- new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
- splits(featureIndex)(0), Categorical, key)
+ new Bin(
+ new DummyCategoricalSplit(featureIndex, Categorical),
+ splits(featureIndex)(0),
+ Categorical,
+ Double.MinValue)
} else {
- new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
- Categorical, key)
+ new Bin(
+ splits(featureIndex)(index - 1),
+ splits(featureIndex)(index),
+ Categorical,
+ Double.MinValue)
}
}
+ index += 1
+ }
+ } else {
+
+ val centroidForCategories = {
+ if (isMulticlassClassification) {
+ // For categorical variables in multiclass classification,
+ // each bin is a category. The bins are sorted and they
+ // are ordered by calculating the impurity of their corresponding labels.
+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+ .groupBy(_._1)
+ .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
+ .map(x => (x._1, x._2.values.toArray))
+ .map(x => (x._1, strategy.impurity.calculate(x._2,x._2.sum)))
+ } else { // regression or binary classification
+ // For categorical variables in regression and binary classification,
+ // each bin is a category. The bins are sorted and they
+ // are ordered by calculating the centroid of their corresponding labels.
+ sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+ .groupBy(_._1)
+ .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+ }
+ }
+
+ logDebug("centriod for categories = " + centroidForCategories.mkString(","))
+
+          // Check for missing categorical variables and put them last in the sorted list.
+ val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
+ for (i <- 0 until featureCategories) {
+ if (centroidForCategories.contains(i)) {
+ fullCentroidForCategories(i) = centroidForCategories(i)
+ } else {
+ fullCentroidForCategories(i) = Double.MaxValue
+ }
+ }
+
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+ logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
+
+ var categoriesForSplit = List[Double]()
+ categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+ case ((key, value), index) =>
+ categoriesForSplit = key :: categoriesForSplit
+ splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
+ Categorical, categoriesForSplit)
+ bins(featureIndex)(index) = {
+ if (index == 0) {
+ new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+ splits(featureIndex)(0), Categorical, key)
+ } else {
+ new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+ Categorical, key)
+ }
+ }
+ }
}
}
featureIndex += 1
@@ -1107,4 +1414,29 @@ object DecisionTree extends Serializable with Logging {
throw new UnsupportedOperationException("approximate histogram not supported yet.")
}
}
+
+ /**
+   * Extracts the list of eligible categories given an index. It extracts the
+   * positions of ones in the binary representation of the input. For example, if the
+   * binary representation of the input is 01101 (13), the output list is (3.0, 2.0, 0.0).
+   * The maxFeatureValue parameter gives the number of rightmost digits tested for ones.
+ */
+ private[tree] def extractMultiClassCategories(
+ input: Int,
+ maxFeatureValue: Int): List[Double] = {
+ var categories = List[Double]()
+ var j = 0
+ var bitShiftedInput = input
+ while (j < maxFeatureValue) {
+ if (bitShiftedInput % 2 != 0) {
+ // updating the list of categories.
+ categories = j.toDouble :: categories
+ }
+ // Right shift by one
+ bitShiftedInput = bitShiftedInput >> 1
+ j += 1
+ }
+ categories
+ }
+
}
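To make the bit extraction concrete, here is a standalone copy of the same logic with a
worked example (editor's illustration only):

    // 13 is 01101 in binary: ones at positions 0, 2 and 3, so categories 0.0, 2.0 and 3.0
    // are eligible. Prepending while scanning bits low-to-high yields a descending list.
    def extractCategories(input: Int, maxFeatureValue: Int): List[Double] = {
      var categories = List[Double]()
      var bitShiftedInput = input
      var j = 0
      while (j < maxFeatureValue) {
        if (bitShiftedInput % 2 != 0) categories = j.toDouble :: categories
        bitShiftedInput = bitShiftedInput >> 1
        j += 1
      }
      categories
    }
    println(extractCategories(13, 10))  // List(3.0, 2.0, 0.0)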
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 1b505fd76e..7c027ac2fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
+ * @param numClassesForClassification number of classes for classification. The default
+ * value of 2 corresponds to binary classification.
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
+ val numClassesForClassification: Int = 2,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val maxMemoryInMB: Int = 128) extends Serializable
+ val maxMemoryInMB: Int = 128) extends Serializable {
+
+ require(numClassesForClassification >= 2)
+ val isMulticlassClassification = numClassesForClassification > 2
+ val isMulticlassWithCategoricalFeatures
+ = isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+
+}
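A short sketch of how the new parameter and the derived flags interact (editor's
illustration; the values are arbitrary):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Entropy

    // A 3-class problem where feature 0 is categorical with 4 categories.
    val strategy = new Strategy(Classification, Entropy, maxDepth = 4,
      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 4))
    assert(strategy.isMulticlassClassification)
    assert(strategy.isMulticlassWithCategoricalFeatures)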
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 60f43e9278..a0e2d91762 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -31,23 +31,35 @@ object Entropy extends Impurity {
/**
* :: DeveloperApi ::
- * entropy calculation
- * @param c0 count of instances with label 0
- * @param c1 count of instances with label 1
- * @return entropy value
+ * information calculation for multiclass classification
+ * @param counts Array[Double] with counts for each label
+ * @param totalCount sum of counts for all labels
+ * @return information value
*/
@DeveloperApi
- override def calculate(c0: Double, c1: Double): Double = {
- if (c0 == 0 || c1 == 0) {
- 0
- } else {
- val total = c0 + c1
- val f0 = c0 / total
- val f1 = c1 / total
- -(f0 * log2(f0)) - (f1 * log2(f1))
+ override def calculate(counts: Array[Double], totalCount: Double): Double = {
+ val numClasses = counts.length
+ var impurity = 0.0
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ val classCount = counts(classIndex)
+ if (classCount != 0) {
+ val freq = classCount / totalCount
+ impurity -= freq * log2(freq)
+ }
+ classIndex += 1
}
+ impurity
}
+ /**
+ * :: DeveloperApi ::
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ */
+ @DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
}
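A quick worked example of the generalized entropy (editor's illustration): for class counts
(2, 2, 4) with totalCount = 8 the frequencies are 0.25, 0.25 and 0.5, so the impurity is
-0.25*log2(0.25) - 0.25*log2(0.25) - 0.5*log2(0.5) = 0.5 + 0.5 + 0.5 = 1.5. With two classes
this reduces to the old binary formula.

    import org.apache.spark.mllib.tree.impurity.Entropy
    val impurity = Entropy.calculate(Array(2.0, 2.0, 4.0), 8.0)  // 1.5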
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c51d76d9b4..48144b5e6d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -30,23 +30,32 @@ object Gini extends Impurity {
/**
* :: DeveloperApi ::
- * Gini coefficient calculation
- * @param c0 count of instances with label 0
- * @param c1 count of instances with label 1
- * @return Gini coefficient value
+ * information calculation for multiclass classification
+ * @param counts Array[Double] with counts for each label
+ * @param totalCount sum of counts for all labels
+ * @return information value
*/
@DeveloperApi
- override def calculate(c0: Double, c1: Double): Double = {
- if (c0 == 0 || c1 == 0) {
- 0
- } else {
- val total = c0 + c1
- val f0 = c0 / total
- val f1 = c1 / total
- 1 - f0 * f0 - f1 * f1
+ override def calculate(counts: Array[Double], totalCount: Double): Double = {
+ val numClasses = counts.length
+ var impurity = 1.0
+ var classIndex = 0
+ while (classIndex < numClasses) {
+ val freq = counts(classIndex) / totalCount
+ impurity -= freq * freq
+ classIndex += 1
}
+ impurity
}
+ /**
+ * :: DeveloperApi ::
+ * variance calculation
+ * @param count number of instances
+ * @param sum sum of labels
+ * @param sumSquares summation of squares of the labels
+ */
+ @DeveloperApi
override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
}
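The analogous check for the Gini index (editor's illustration): for counts (2, 2, 4) with
totalCount = 8 the impurity is 1 - (0.25^2 + 0.25^2 + 0.5^2) = 1 - 0.375 = 0.625.

    import org.apache.spark.mllib.tree.impurity.Gini
    val impurity = Gini.calculate(Array(2.0, 2.0, 4.0), 8.0)  // 0.625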
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 8eab247cf0..7b2a9320cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -28,13 +28,13 @@ trait Impurity extends Serializable {
/**
* :: DeveloperApi ::
- * information calculation for binary classification
- * @param c0 count of instances with label 0
- * @param c1 count of instances with label 1
+ * information calculation for multiclass classification
+ * @param counts Array[Double] with counts for each label
+ * @param totalCount sum of counts for all labels
* @return information value
*/
@DeveloperApi
- def calculate(c0 : Double, c1 : Double): Double
+ def calculate(counts: Array[Double], totalCount: Double): Double
/**
* :: DeveloperApi ::
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 47d07122af..97149a99ea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
*/
@Experimental
object Variance extends Impurity {
- override def calculate(c0: Double, c1: Double): Double =
+
+ /**
+ * :: DeveloperApi ::
+ * information calculation for multiclass classification
+ * @param counts Array[Double] with counts for each label
+ * @param totalCount sum of counts for all labels
+ * @return information value
+ */
+ @DeveloperApi
+ override def calculate(counts: Array[Double], totalCount: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")
/**
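For contrast, the three-argument overload below is the one Variance actually implements. A
worked example, assuming the usual population-variance formula
(sumSquares - sum * sum / count) / count:

    import org.apache.spark.mllib.tree.impurity.Variance
    // Labels 1, 2, 3, 4: count = 4, sum = 10, sumSquares = 30, variance = (30 - 25) / 4
    val impurity = Variance.calculate(4.0, 10.0, 30.0)  // 1.25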
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 2d71e1e366..c89c1e371a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
- * @param category categorical label value accepted in the bin
+ * @param category categorical feature value accepted in the bin for binary classification
*/
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index cc8a24cce9..fb12298e0f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
+ * @param prob probability of the label (classification only)
*/
@DeveloperApi
class InformationGainStats(
@@ -34,10 +35,11 @@ class InformationGainStats(
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
- val predict: Double) extends Serializable {
+ val predict: Double,
+ val prob: Double = 0.0) extends Serializable {
override def toString = {
- "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
- .format(gain, impurity, leftImpurity, rightImpurity, predict)
+ "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+ .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
}
}
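A tiny sketch of the extended constructor (editor's illustration; the values are arbitrary):

    import org.apache.spark.mllib.tree.model.InformationGainStats
    // Predict class 2 with probability 0.6; prob defaults to 0.0 for regression callers.
    val stats = new InformationGainStats(0.1, 0.5, 0.4, 0.3, 2.0, 0.6)
    println(stats)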
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bcb11876b8..5961a618c5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
import org.scalatest.FunSuite
-import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 100)
+ val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
@@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Classification,
Gini,
maxDepth = 3,
+ numClassesForClassification = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
@@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
Classification,
Gini,
maxDepth = 3,
+ numClassesForClassification = 2,
maxBins = 100,
- categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
// Check splits.
@@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins(1)(3) === null)
}
+ test("extract categories from a number for multiclass classification") {
+ val l = DecisionTree.extractMultiClassCategories(13, 10)
+ assert(l.length === 3)
+ assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
+ }
+
+ test("split and bin calculations for unordered categorical variables with multiclass " +
+ "classification") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ numClassesForClassification = 100,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // With 3 categories, expecting 2^(3-1) - 1 = 3 unordered splits (and 3 bins)
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(0.0))
+ assert(splits(1)(0).feature === 1)
+ assert(splits(1)(0).threshold === Double.MinValue)
+ assert(splits(1)(0).featureType === Categorical)
+ assert(splits(1)(0).categories.length === 1)
+ assert(splits(1)(0).categories.contains(0.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 1)
+ assert(splits(0)(1).categories.contains(1.0))
+ assert(splits(1)(1).feature === 1)
+ assert(splits(1)(1).threshold === Double.MinValue)
+ assert(splits(1)(1).featureType === Categorical)
+ assert(splits(1)(1).categories.length === 1)
+ assert(splits(1)(1).categories.contains(1.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 2)
+ assert(splits(0)(2).categories.contains(0.0))
+ assert(splits(0)(2).categories.contains(1.0))
+ assert(splits(1)(2).feature === 1)
+ assert(splits(1)(2).threshold === Double.MinValue)
+ assert(splits(1)(2).featureType === Categorical)
+ assert(splits(1)(2).categories.length === 2)
+ assert(splits(1)(2).categories.contains(0.0))
+ assert(splits(1)(2).categories.contains(1.0))
+
+ assert(splits(0)(3) === null)
+ assert(splits(1)(3) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === Double.MinValue)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(0.0))
+ assert(bins(1)(0).category === Double.MinValue)
+ assert(bins(1)(0).lowSplit.categories.length === 0)
+ assert(bins(1)(0).highSplit.categories.length === 1)
+ assert(bins(1)(0).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(1).category === Double.MinValue)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).lowSplit.categories.contains(0.0))
+ assert(bins(0)(1).highSplit.categories.length === 1)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(1)(1).category === Double.MinValue)
+ assert(bins(1)(1).lowSplit.categories.length === 1)
+ assert(bins(1)(1).lowSplit.categories.contains(0.0))
+ assert(bins(1)(1).highSplit.categories.length === 1)
+ assert(bins(1)(1).highSplit.categories.contains(1.0))
+
+ assert(bins(0)(2).category === Double.MinValue)
+ assert(bins(0)(2).lowSplit.categories.length === 1)
+ assert(bins(0)(2).lowSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.length === 2)
+ assert(bins(0)(2).highSplit.categories.contains(1.0))
+ assert(bins(0)(2).highSplit.categories.contains(0.0))
+ assert(bins(1)(2).category === Double.MinValue)
+ assert(bins(1)(2).lowSplit.categories.length === 1)
+ assert(bins(1)(2).lowSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.length === 2)
+ assert(bins(1)(2).highSplit.categories.contains(1.0))
+ assert(bins(1)(2).highSplit.categories.contains(0.0))
+
+ assert(bins(0)(3) === null)
+ assert(bins(1)(3) === null)
+ }
+
+ test("split and bin calculations for ordered categorical variables with multiclass " +
+ "classification") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ assert(arr.length === 3000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Classification,
+ Gini,
+ maxDepth = 3,
+ numClassesForClassification = 100,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+ // 2^(10-1) - 1 = 511 > 100 bins, so the categorical features will be ordered
+
+ assert(splits(0)(0).feature === 0)
+ assert(splits(0)(0).threshold === Double.MinValue)
+ assert(splits(0)(0).featureType === Categorical)
+ assert(splits(0)(0).categories.length === 1)
+ assert(splits(0)(0).categories.contains(1.0))
+
+ assert(splits(0)(1).feature === 0)
+ assert(splits(0)(1).threshold === Double.MinValue)
+ assert(splits(0)(1).featureType === Categorical)
+ assert(splits(0)(1).categories.length === 2)
+ assert(splits(0)(1).categories.contains(2.0))
+
+ assert(splits(0)(2).feature === 0)
+ assert(splits(0)(2).threshold === Double.MinValue)
+ assert(splits(0)(2).featureType === Categorical)
+ assert(splits(0)(2).categories.length === 3)
+ assert(splits(0)(2).categories.contains(2.0))
+ assert(splits(0)(2).categories.contains(1.0))
+
+ assert(splits(0)(10) === null)
+ assert(splits(1)(10) === null)
+
+ // Check bins.
+
+ assert(bins(0)(0).category === 1.0)
+ assert(bins(0)(0).lowSplit.categories.length === 0)
+ assert(bins(0)(0).highSplit.categories.length === 1)
+ assert(bins(0)(0).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).category === 2.0)
+ assert(bins(0)(1).lowSplit.categories.length === 1)
+ assert(bins(0)(1).highSplit.categories.length === 2)
+ assert(bins(0)(1).highSplit.categories.contains(1.0))
+ assert(bins(0)(1).highSplit.categories.contains(2.0))
+
+ assert(bins(0)(10) === null)
+ }
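+
+ // Hypothetical helper (not part of this patch) summarizing the rule the two
+ // tests above exercise: in multiclass classification a categorical feature
+ // is treated as unordered only while its 2^(M-1) - 1 candidate splits fit
+ // within maxBins; otherwise it falls back to ordered treatment.
+ //   def isUnordered(numCategories: Int, maxBins: Int): Boolean =
+ //     math.pow(2, numCategories - 1) - 1 <= maxBins
+ //   isUnordered(3, 100)   // true:  3 candidate splits (unordered test)
+ //   isUnordered(10, 100)  // false: 511 > 100 bins     (ordered test)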
+
test("classification stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
@@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategy = new Strategy(
Classification,
Gini,
+ numClassesForClassification = 2,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
@@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.5)
- assert(stats.predict < 0.7)
+ assert(stats.predict === 1)
+ assert(stats.prob === 0.6)
assert(stats.impurity > 0.2)
}
@@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val stats = bestSplits(0)._2
assert(stats.gain > 0)
- assert(stats.predict > 0.5)
- assert(stats.predict < 0.7)
+ assert(stats.predict === 0.6)
assert(stats.impurity > 0.2)
}
@@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 100)
+ val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Gini, 3, 100)
+ val strategy = new Strategy(Classification, Gini, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bestSplits(0)._2.predict === 1)
}
- test("test second level node building with/without groups") {
+ test("second level node building with/without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 100)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
@@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
}
+ test("stump with categorical variables for multiclass classification") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
+ assert(strategy.isMulticlassClassification)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ Array[List[Filter]](), splits, bins, 10)
+
+ assert(bestSplits.length === 1)
+ val bestSplit = bestSplits(0)._1
+ assert(bestSplit.feature === 0)
+ assert(bestSplit.categories.length === 1)
+ assert(bestSplit.categories.contains(1.0))
+ assert(bestSplit.featureType === Categorical)
+ }
+
+ test("stump with continuous variables for multiclass classification") {
+ val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3)
+ assert(strategy.isMulticlassClassification)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ Array[List[Filter]](), splits, bins, 10)
+
+ assert(bestSplits.length === 1)
+ val bestSplit = bestSplits(0)._1
+
+ assert(bestSplit.feature === 1)
+ assert(bestSplit.featureType === Continuous)
+ assert(bestSplit.threshold > 1980)
+ assert(bestSplit.threshold < 2020)
+ }
+
+ test("stump with continuous + categorical variables for multiclass classification") {
+ val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
+ assert(strategy.isMulticlassClassification)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ Array[List[Filter]](), splits, bins, 10)
+
+ assert(bestSplits.length === 1)
+ val bestSplit = bestSplits(0)._1
+
+ assert(bestSplit.feature === 1)
+ assert(bestSplit.featureType === Continuous)
+ assert(bestSplit.threshold > 1980)
+ assert(bestSplit.threshold < 2020)
+ }
+
+ test("stump with categorical variables for ordered multiclass classification") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ val input = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+ assert(strategy.isMulticlassClassification)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+ val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ Array[List[Filter]](), splits, bins, 10)
+
+ assert(bestSplits.length === 1)
+ val bestSplit = bestSplits(0)._1
+ assert(bestSplit.feature === 0)
+ assert(bestSplit.categories.length === 1)
+ assert(bestSplit.categories.contains(1.0))
+ assert(bestSplit.featureType === Categorical)
+ }
+
}
object DecisionTreeSuite {
@@ -473,4 +707,47 @@ object DecisionTreeSuite {
}
arr
}
+
+ def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = {
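+ // Feature 0 tracks the label (1.0 -> class 1.0, 2.0 -> class 2.0) while
+ // feature 1 is constant, so a stump is expected to split on feature 0.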
+ val arr = new Array[LabeledPoint](3000)
+ for (i <- 0 until 3000) {
+ if (i < 1000) {
+ arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+ } else if (i < 2000) {
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+ } else {
+ arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+ }
+ }
+ arr
+ }
+
+ def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = {
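+ // Feature 1 equals the row index: rows 0-1999 are labeled 2.0 and rows
+ // 2000-2999 are labeled 1.0, so the best continuous split should fall near
+ // 2000 (the stump tests above assert a threshold between 1980 and 2020).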
+ val arr = new Array[LabeledPoint](3000)
+ for (i <- 0 until 3000) {
+ if (i < 2000) {
+ arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i))
+ } else {
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i))
+ }
+ }
+ arr
+ }
+
+ def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
+ Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](3000)
+ for (i <- 0 until 3000) {
+ if (i < 1000) {
+ arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+ } else if (i < 2000) {
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+ } else {
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0))
+ }
+ }
+ arr
+ }
+
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3487f7c5c1..e0f433b26f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -82,7 +82,15 @@ object MimaExcludes {
MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++
MimaBuild.excludeSparkClass("storage.Values") ++
MimaBuild.excludeSparkClass("storage.Entry") ++
- MimaBuild.excludeSparkClass("storage.MemoryStore$Entry")
+ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++
+ Seq(
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Gini.calculate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Entropy.calculate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.mllib.tree.impurity.Variance.calculate")
+ )
case v if v.startsWith("1.0") =>
Seq(
MimaBuild.excludeSparkPackage("api.java"),
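
The three impurity filters are needed because this patch changes the calculate signature on Gini, Entropy and Variance from a two-class form to an array of per-class counts, a binary-incompatible change. Roughly (the old form is inferred from the previous release, not shown in this patch):

    def calculate(c0: Double, c1: Double): Double                     // old: two binary class counts
    def calculate(counts: Array[Double], totalCount: Double): Double  // new: one count per class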