aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-17 16:44:41 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-17 16:44:41 -0700
commit1614485fd92fc94bc3989da49be612e542b93fb8 (patch)
treeca71f1fee26fcb993316bd0a6e545ae2003ab107
parentb39e80d39dae8e6779f9d78c1631a27585239032 (diff)
downloadspark-1614485fd92fc94bc3989da49be612e542b93fb8.tar.gz
spark-1614485fd92fc94bc3989da49be612e542b93fb8.tar.bz2
spark-1614485fd92fc94bc3989da49be612e542b93fb8.zip
[SPARK-10788][MLLIB][ML] Remove duplicate bins for decision trees
Decision trees in spark.ml (RandomForest.scala) communicate twice as much data as needed for unordered categorical features. Here's an example. Say there are 3 categories A, B, C. We consider 3 splits: * A vs. B, C * A, B vs. C * A, C vs. B Currently, we collect statistics for each of the 6 subsets of categories (3 * 2 = 6). However, we could instead collect statistics for the 3 subsets on the left-hand side of the 3 possible splits: A and A,B and A,C. If we also have stats for the entire node, then we can compute the stats for the 3 subsets on the right-hand side of the splits. In pseudomath: stats(B,C) = stats(A,B,C) - stats(A). This patch adds a parent stats array to the `DTStatsAggregator` so that the right child stats do not need to be stored. The right child stats are computed by subtracting left child stats from the parent stats for unordered categorical features. Author: sethah <seth.hendrickson16@gmail.com> Closes #9474 from sethah/SPARK-10788.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala59
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala4
9 files changed, 54 insertions, 49 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc98569a..dd9a5f261f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
- val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
- } else {
- agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
}
splitIndex += 1
}
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
}
}
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
// Calculate InformationGain and ImpurityStats if current node is top node
val level = LearningNode.indexToLevel(node.id)
- var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
null
} else {
node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
- val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats =
- binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIndex, gainAndImpurityStats)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e65f1..c0934d241f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
/**
* Method to train a decision tree model over an RDD
+ *
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return DecisionTreeModel that can be used for prediction.
*/
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
- val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
- } else {
- agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
- instanceWeight)
}
splitIndex += 1
}
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
}
}
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
- val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats =
- binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4b4c..c745e9f8db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,26 +73,34 @@ private[spark] class DTStatsAggregator(
* Flat array of elements.
* Index for start of stats for a (feature, bin) is:
* index = featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features,
- * the left child stats have binIndex in [0, numBins(featureIndex) / 2))
- * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
*/
private val allStats: Array[Double] = new Array[Double](allStatsSize)
+ /**
+ * Array of parent node sufficient stats.
+ *
+ * Note: this is necessary because stats for the parent node are not available
+ * on the first iteration of tree learning.
+ */
+ private val parentStats: Array[Double] = new Array[Double](statsSize)
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
- * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
*/
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
+ * Get an [[ImpurityCalculator]] for the parent node.
+ */
+ def getParentImpurityCalculator(): ImpurityCalculator = {
+ impurityAggregator.getCalculator(parentStats, 0)
+ }
+
+ /**
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
@@ -101,13 +109,17 @@ private[spark] class DTStatsAggregator(
}
/**
+ * Update the parent node stats using the given label.
+ */
+ def updateParent(label: Double, instanceWeight: Double): Unit = {
+ impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ }
+
+ /**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
- * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
*/
def featureUpdate(
featureOffset: Int,
@@ -125,21 +137,9 @@ private[spark] class DTStatsAggregator(
def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
/**
- * Pre-compute feature offset for use with [[featureUpdate]].
- * For unordered features only.
- */
- def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
- val baseOffset = featureOffsets(featureIndex)
- (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
- }
-
- /**
* For a given feature, merge the stats for two bins.
- * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
allStats(i) += other.allStats(i)
i += 1
}
+
+ require(statsSize == other.statsSize,
+ s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+ s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+ s"but the other is ${other.statsSize}.")
+ var j = 0
+ while (j < statsSize) {
+ parentStats(j) += other.parentStats(j)
+ j += 1
+ }
+
this
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d291ca..4f27dc44ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
/**
* Number of splits for the given feature.
- * For unordered features, there are 2 bins per split.
+ * For unordered features, there is 1 bin per split.
* For ordered features, there is 1 more bin than split.
*/
def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
- numBins(featureIndex) >> 1
+ numBins(featureIndex)
} else {
numBins(featureIndex) - 1
}
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
* there are math.pow(2, arity - 1) - 1 such splits.
* Each split has 2 corresponding bins.
*/
- def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+ def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b054a..13aff11007 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a..39c7f9c3be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2bba4..65f0163ec6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a..92d74a1b83 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527..89b64fce96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(bins.length === 2)
assert(splits(0).length === 3)
assert(bins(0).length === 0)
+ assert(metadata.numSplits(0) === 3)
+ assert(metadata.numBins(0) === 3)
+ assert(metadata.numSplits(1) === 3)
+ assert(metadata.numBins(1) === 3)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)