diff options
9 files changed, 54 insertions, 49 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 91dc98569a..dd9a5f261f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) val featureSplits = splits(featureIndex) @@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging { while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) } splitIndex += 1 } @@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, metadata.unorderedFeatures, instanceWeight, featuresForNode) } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } @@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging { // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level ==0) { + var gainAndImpurityStats: ImpurityStats = if (level == 0) { null } else { node.stats @@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = - binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, leftChildStats, rightChildStats, binAggregates.metadata) (splitIndex, gainAndImpurityStats) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 18f66e65f1..c0934d241f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy) /** * Method to train a decision tree model over an RDD + * * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return DecisionTreeModel that can be used for prediction. */ @@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging { if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) - val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightFeatureOffsets(featureIndexIdx) + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) // Update the left or right bin for each split. val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 @@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging { if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) { agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } else { - agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, - instanceWeight) } splitIndex += 1 } @@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, metadata.unorderedFeatures, instanceWeight, featuresForNode) } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) } } @@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging { (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { // Unordered categorical feature - val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightFeatureOffsets(featureIndexIdx) + val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = - binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) predictWithImpurity = Some(predictWithImpurity.getOrElse( calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 7985ed4b4c..c745e9f8db 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -73,26 +73,34 @@ private[spark] class DTStatsAggregator( * Flat array of elements. * Index for start of stats for a (feature, bin) is: * index = featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, - * the left child stats have binIndex in [0, numBins(featureIndex) / 2)) - * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex)) */ private val allStats: Array[Double] = new Array[Double](allStatsSize) + /** + * Array of parent node sufficient stats. + * + * Note: this is necessary because stats for the parent node are not available + * on the first iteration of tree learning. + */ + private val parentStats: Array[Double] = new Array[Double](statsSize) /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset + * @param featureOffset This is a pre-computed (node, feature) offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (node, feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) } /** + * Get an [[ImpurityCalculator]] for the parent node. + */ + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats, 0) + } + + /** * Update the stats for a given (feature, bin) for ordered features, using the given label. */ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { @@ -101,13 +109,17 @@ private[spark] class DTStatsAggregator( } /** + * Update the parent node stats using the given label. + */ + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } + + /** * Faster version of [[update]]. * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset For ordered features, this is a pre-computed feature offset + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. */ def featureUpdate( featureOffset: Int, @@ -125,21 +137,9 @@ private[spark] class DTStatsAggregator( def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For unordered features only. - */ - def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - val baseOffset = featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) - } - - /** * For a given feature, merge the stats for two bins. - * @param featureOffset For ordered features, this is a pre-computed feature offset + * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. - * For unordered features, this is a pre-computed - * (feature, left/right child) offset from - * [[getLeftRightFeatureOffsets]]. * @param binIndex The other bin is merged into this bin. * @param otherBinIndex This bin is not modified. */ @@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator( allStats(i) += other.allStats(i) i += 1 } + + require(statsSize == other.statsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + + s"stats vectors. This aggregator's parent stats are length $statsSize, " + + s"but the other is ${other.statsSize}.") + var j = 0 + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + this } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index df13d291ca..4f27dc44ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata( /** * Number of splits for the given feature. - * For unordered features, there are 2 bins per split. + * For unordered features, there is 1 bin per split. * For ordered features, there is 1 more bin than split. */ def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) >> 1 + numBins(featureIndex) } else { numBins(featureIndex) - 1 } @@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging { * there are math.pow(2, arity - 1) - 1 such splits. * Each split has 2 corresponding bins. */ - def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 73df6b054a..13aff11007 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int) def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index f21845b21a..39c7f9c3be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int) def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index b2c6e2bba4..65f0163ec6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator - } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 09017d482a..92d74a1b83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -93,7 +93,6 @@ private[tree] class VarianceAggregator() def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) } - } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 5518bdf527..89b64fce96 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { assert(bins.length === 2) assert(splits(0).length === 3) assert(bins(0).length === 0) + assert(metadata.numSplits(0) === 3) + assert(metadata.numBins(0) === 3) + assert(metadata.numSplits(1) === 3) + assert(metadata.numBins(1) === 3) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) |