about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author    Liang-Chi Hsieh <viirya@gmail.com>    2016-02-09 17:10:55 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2016-02-09 17:10:55 -0800
commit    9267bc68fab65c6a798e065a1dbe0f5171df3077 (patch)
tree      afbf5313cbc324c134b2c9dc20ed56860bf7e427 /mllib
parent    0e5ebac3c1f1ff58f938be59c7c9e604977d269c (diff)
download  spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.gz
          spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.tar.bz2
          spark-9267bc68fab65c6a798e065a1dbe0f5171df3077.zip
[SPARK-10524][ML] Use the soft prediction to order categories' bins
JIRA: https://issues.apache.org/jira/browse/SPARK-10524 Currently we use the hard prediction (`ImpurityCalculator.predict`) to order categories' bins. But we should use the soft prediction. Author: Liang-Chi Hsieh <viirya@gmail.com> Author: Liang-Chi Hsieh <viirya@appier.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #8734 from viirya/dt-soft-centroids.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala   42
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala     219
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala   36
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala   30
4 files changed, 194 insertions, 133 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index d3376a7dff..ea733d577a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -650,7 +650,7 @@ private[ml] object RandomForest extends Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
- private def binsToBestSplit(
+ private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
@@ -720,32 +720,30 @@ private[ml] object RandomForest extends Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
- val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
- // For categorical variables in multiclass classification,
- // the bins are ordered by the impurity of their corresponding labels.
- Range(0, numCategories).map { case featureValue =>
- val categoryStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val centroid = if (categoryStats.count != 0) {
+ val centroidForCategories = Range(0, numCategories).map { case featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val centroid = if (categoryStats.count != 0) {
+ if (binAggregates.metadata.isMulticlass) {
+ // multiclass classification
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // binary classification
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
} else {
- Double.MaxValue
- }
- (featureValue, centroid)
- }
- } else { // regression or binary classification
- // For categorical variables in regression and binary classification,
- // the bins are ordered by the centroid of their corresponding labels.
- Range(0, numCategories).map { case featureValue =>
- val categoryStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val centroid = if (categoryStats.count != 0) {
+ // regression
+ // For categorical variables in regression and binary classification,
+ // the bins are ordered by the prediction.
categoryStats.predict
- } else {
- Double.MaxValue
}
- (featureValue, centroid)
+ } else {
+ Double.MaxValue
}
+ (featureValue, centroid)
}
logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 07ba0d8ccb..51235a2371 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -791,7 +791,7 @@ object DecisionTree extends Serializable with Logging {
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
- private def binsToBestSplit(
+ private[tree] def binsToBestSplit(
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
@@ -808,128 +808,127 @@ object DecisionTree extends Serializable with Logging {
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
- val featureIndex = if (featuresForNode.nonEmpty) {
- featuresForNode.get.apply(featureIndexIdx)
- } else {
- featureIndexIdx
- }
- val numSplits = binAggregates.metadata.numSplits(featureIndex)
- if (binAggregates.metadata.isContinuous(featureIndex)) {
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
- splitIndex += 1
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
}
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { case splitIdx =>
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
- val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
- rightChildStats.subtract(leftChildStats)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIdx, gainStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (binAggregates.metadata.isUnordered(featureIndex)) {
- // Unordered categorical feature
- val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIndex, gainStats)
- }.maxBy(_._2.gain)
- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else {
- // Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
- val numBins = binAggregates.metadata.numBins(featureIndex)
-
- /* Each bin is one category (feature value).
- * The bins are ordered based on centroidForCategories, and this ordering determines which
- * splits are considered. (With K categories, we consider K - 1 possible splits.)
- *
- * centroidForCategories is a list: (category, centroid)
- */
- val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
- // For categorical variables in multiclass classification,
- // the bins are ordered by the impurity of their corresponding labels.
- Range(0, numBins).map { case featureValue =>
- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val centroid = if (categoryStats.count != 0) {
- categoryStats.calculate()
- } else {
- Double.MaxValue
- }
- (featureValue, centroid)
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ splitIndex += 1
}
- } else { // regression or binary classification
- // For categorical variables in regression and binary classification,
- // the bins are ordered by the centroid of their corresponding labels.
- Range(0, numBins).map { case featureValue =>
- val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { case splitIdx =>
+ val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
+ rightChildStats.subtract(leftChildStats)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+ (splitIdx, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
+ // Unordered categorical feature
+ val (leftChildOffset, rightChildOffset) =
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+ } else {
+ // Ordered categorical feature
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numBins = binAggregates.metadata.numBins(featureIndex)
+
+ /* Each bin is one category (feature value).
+ * The bins are ordered based on centroidForCategories, and this ordering determines which
+ * splits are considered. (With K categories, we consider K - 1 possible splits.)
+ *
+ * centroidForCategories is a list: (category, centroid)
+ */
+ val centroidForCategories = Range(0, numBins).map { case featureValue =>
+ val categoryStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
val centroid = if (categoryStats.count != 0) {
- categoryStats.predict
+ if (binAggregates.metadata.isMulticlass) {
+ // For categorical variables in multiclass classification,
+ // the bins are ordered by the impurity of their corresponding labels.
+ categoryStats.calculate()
+ } else if (binAggregates.metadata.isClassification) {
+ // For categorical variables in binary classification,
+ // the bins are ordered by the count of class 1.
+ categoryStats.stats(1)
+ } else {
+ // For categorical variables in regression,
+ // the bins are ordered by the prediction.
+ categoryStats.predict
+ }
} else {
Double.MaxValue
}
(featureValue, centroid)
}
- }
- logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
+ logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))
- // bins sorted by centroids
- val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
+ // bins sorted by centroids
+ val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
- logDebug("Sorted centroids for categorical variable = " +
- categoriesSortedByCentroid.mkString(","))
+ logDebug("Sorted centroids for categorical variable = " +
+ categoriesSortedByCentroid.mkString(","))
- // Cumulative sum (scanLeft) of bin statistics.
- // Afterwards, binAggregates for a bin is the sum of aggregates for
- // that bin + all preceding bins.
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val currentCategory = categoriesSortedByCentroid(splitIndex)._1
- val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
- splitIndex += 1
+ // Cumulative sum (scanLeft) of bin statistics.
+ // Afterwards, binAggregates for a bin is the sum of aggregates for
+ // that bin + all preceding bins.
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val currentCategory = categoriesSortedByCentroid(splitIndex)._1
+ val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ splitIndex += 1
+ }
+ // lastCategory = index of bin with total aggregates for this (node, feature)
+ val lastCategory = categoriesSortedByCentroid.last._1
+ // Find best split.
+ val (bestFeatureSplitIndex, bestFeatureGainStats) =
+ Range(0, numSplits).map { splitIndex =>
+ val featureValue = categoriesSortedByCentroid(splitIndex)._1
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+ val rightChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+ rightChildStats.subtract(leftChildStats)
+ predictWithImpurity = Some(predictWithImpurity.getOrElse(
+ calculatePredictImpurity(leftChildStats, rightChildStats)))
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
+ (splitIndex, gainStats)
+ }.maxBy(_._2.gain)
+ val categoriesForSplit =
+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
+ val bestFeatureSplit =
+ new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
+ (bestFeatureSplit, bestFeatureGainStats)
}
- // lastCategory = index of bin with total aggregates for this (node, feature)
- val lastCategory = categoriesSortedByCentroid.last._1
- // Find best split.
- val (bestFeatureSplitIndex, bestFeatureGainStats) =
- Range(0, numSplits).map { splitIndex =>
- val featureValue = categoriesSortedByCentroid(splitIndex)._1
- val leftChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
- val rightChildStats =
- binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
- rightChildStats.subtract(leftChildStats)
- predictWithImpurity = Some(predictWithImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
- (splitIndex, gainStats)
- }.maxBy(_._2.gain)
- val categoriesForSplit =
- categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
- val bestFeatureSplit =
- new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
- (bestFeatureSplit, bestFeatureGainStats)
- }
}.maxBy(_._2.gain)
(bestSplit, bestSplitStats, predictWithImpurity.get._1)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index fda2711fed..baf6b90839 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.LeafNode
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -275,6 +275,40 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
val model = dt.fit(df)
}
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val data = sc.parallelize(arr)
+ val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(1)
+ .setMaxBins(3)
+ val model = dt.fit(df)
+ model.rootNode match {
+ case n: InternalNode =>
+ n.split match {
+ case s: CategoricalSplit =>
+ assert(s.leftCategories === Array(1.0))
+ }
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index a9c935bd42..dca8ea815a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, Tree
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
@@ -337,6 +338,35 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(topNode.rightNode.get.impurity === 0.0)
}
+ test("Use soft prediction for binary classification with ordered categorical features") {
+ // The following dataset is set up such that the best split is {1} vs. {0, 2}.
+ // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)))
+ val input = sc.parallelize(arr)
+
+ // Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
+
+ val model = new DecisionTree(strategy).run(input)
+ model.topNode.split.get match {
+ case Split(_, _, _, categories: List[Double]) =>
+ assert(categories === List(1.0))
+ }
+ }
+
test("Second level node building with vs. without groups") {
val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
assert(arr.length === 1000)