-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala  15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala  59
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala  1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  4
9 files changed, 54 insertions, 49 deletions
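
In short: for unordered categorical features, the aggregator previously kept two bins per candidate split (left- and right-child stats); it now keeps only the left bin plus a single per-node parent stats vector, and derives the right child's stats by subtraction. A minimal sketch of that identity, with plain arrays of weighted class counts standing in for Spark's ImpurityCalculator (all names below are illustrative, not Spark's API):

    // Sketch: parentStats == leftChildStats + rightChildStats for every
    // candidate split, so right-child stats are never stored -- they are
    // recovered by subtraction.
    object RightChildBySubtraction extends App {
      def subtract(parent: Array[Double], left: Array[Double]): Array[Double] =
        parent.zip(left).map { case (p, l) => p - l }

      val parent = Array(10.0, 6.0) // weighted counts of classes 0 and 1 at the node
      val left   = Array(4.0, 1.0)  // counts routed left by one candidate split
      println(subtract(parent, left).mkString(", ")) // 6.0, 5.0: the right child
    }
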
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc98569a..dd9a5f261f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
- val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
- } else {
- agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
}
splitIndex += 1
}
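
With right-child bins gone, the loop above bumps at most one bin per split: an instance contributes to a split's bin only when shouldGoLeft holds, and instances routed right are accounted for solely by the parent stats introduced in the next hunk. A standalone sketch of the loop's new shape (predicate and weights are stand-ins, not Spark's API):

    // Sketch: per-instance update for one unordered feature. Only the single
    // left bin per split is bumped; instances routed right leave no trace here.
    object LeftOnlyUpdate extends App {
      val numSplits      = 3
      val leftBins       = Array.fill(numSplits)(0.0)
      val instanceWeight = 1.0
      def shouldGoLeft(splitIndex: Int): Boolean = splitIndex != 1 // stand-in

      var splitIndex = 0
      while (splitIndex < numSplits) {
        if (shouldGoLeft(splitIndex)) leftBins(splitIndex) += instanceWeight
        splitIndex += 1
      }
      println(leftBins.mkString(", ")) // 1.0, 0.0, 1.0
    }
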
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
}
}
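
updateParent runs once per (node, instance), after the per-feature bin updates and regardless of feature types, so parentStats accumulates the node's complete weighted label distribution. A small sketch of that invariant (hypothetical helper, not Spark's API):

    // parentStats ends up holding the node's total weighted class counts,
    // i.e. the left + right total for every candidate split of every feature.
    object ParentAccumulator extends App {
      val parentStats = Array(0.0, 0.0) // one slot per class
      def updateParent(label: Double, weight: Double): Unit =
        parentStats(label.toInt) += weight

      Seq((0.0, 1.0), (1.0, 2.0), (0.0, 1.0)).foreach { case (label, w) =>
        updateParent(label, w)
      }
      println(parentStats.mkString(", ")) // 2.0, 2.0
    }
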
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
// Calculate InformationGain and ImpurityStats if current node is top node
val level = LearningNode.indexToLevel(node.id)
- var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
null
} else {
node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
- val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats =
- binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIndex, gainAndImpurityStats)
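
Calling getParentImpurityCalculator().subtract(leftChildStats) inside the loop over splitIndex is safe because the aggregators build each calculator over a copy of the stats slice (see the .view(offset, offset + statsSize).toArray calls in the Entropy/Gini/Variance hunks below), so the in-place subtraction never mutates the shared parent vector. A simplified sketch of why repeated subtraction works (names and types are illustrative):

    // Each call copies the parent stats, so per-split subtraction is independent.
    object SafePerSplitSubtraction extends App {
      class Calc(val stats: Array[Double]) {
        def subtract(other: Calc): Calc = { // in place, like ImpurityCalculator.subtract
          var i = 0
          while (i < stats.length) { stats(i) -= other.stats(i); i += 1 }
          this
        }
      }
      val parentStats = Array(10.0, 6.0)
      def parentCalc(): Calc = new Calc(parentStats.clone()) // fresh copy per call

      val right0 = parentCalc().subtract(new Calc(Array(4.0, 1.0)))
      val right1 = parentCalc().subtract(new Calc(Array(2.0, 3.0)))
      println(right0.stats.mkString(", ")) // 6.0, 5.0
      println(right1.stats.mkString(", ")) // 8.0, 3.0 -- unaffected by right0
    }
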
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e65f1..c0934d241f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
/**
* Method to train a decision tree model over an RDD
+ *
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
* @return DecisionTreeModel that can be used for prediction.
*/
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
- val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
- } else {
- agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
- instanceWeight)
}
splitIndex += 1
}
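
In this mllib twin of the ml change, routing for an unordered feature is a set-membership test: each candidate split carries the subset of category values that goes left. A tiny illustration with hypothetical subsets:

    // A binned category value goes left iff the split's subset contains it.
    // For arity 3 there are 2^(3-1) - 1 = 3 such splits.
    object UnorderedRouting extends App {
      val splitCategories = Seq(Set(0.0), Set(1.0), Set(0.0, 1.0))
      val featureValue    = 1.0
      println(splitCategories.map(_.contains(featureValue))) // List(false, true, true)
    }
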
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
}
}
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
- val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+ val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
- val rightChildStats =
- binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+ val rightChildStats = binAggregates.getParentImpurityCalculator()
+ .subtract(leftChildStats)
predictWithImpurity = Some(predictWithImpurity.getOrElse(
calculatePredictImpurity(leftChildStats, rightChildStats)))
val gainStats = calculateGainForSplit(leftChildStats,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4b4c..c745e9f8db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,26 +73,34 @@ private[spark] class DTStatsAggregator(
* Flat array of elements.
* Index for start of stats for a (feature, bin) is:
* index = featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features,
- * the left child stats have binIndex in [0, numBins(featureIndex) / 2))
- * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
*/
private val allStats: Array[Double] = new Array[Double](allStatsSize)
+ /**
+ * Array of parent node sufficient stats.
+ *
+ * Note: this is necessary because stats for the parent node are not available
+ * on the first iteration of tree learning.
+ */
+ private val parentStats: Array[Double] = new Array[Double](statsSize)
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
- * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
*/
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
+ * Get an [[ImpurityCalculator]] for the parent node.
+ */
+ def getParentImpurityCalculator(): ImpurityCalculator = {
+ impurityAggregator.getCalculator(parentStats, 0)
+ }
+
+ /**
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
@@ -101,13 +109,17 @@ private[spark] class DTStatsAggregator(
}
/**
+ * Update the parent node stats using the given label.
+ */
+ def updateParent(label: Double, instanceWeight: Double): Unit = {
+ impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ }
+
+ /**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
- * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
*/
def featureUpdate(
featureOffset: Int,
@@ -125,21 +137,9 @@ private[spark] class DTStatsAggregator(
def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
/**
- * Pre-compute feature offset for use with [[featureUpdate]].
- * For unordered features only.
- */
- def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
- val baseOffset = featureOffsets(featureIndex)
- (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
- }
-
- /**
* For a given feature, merge the stats for two bins.
- * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (feature, left/right child) offset from
- * [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
allStats(i) += other.allStats(i)
i += 1
}
+
+ require(statsSize == other.statsSize,
+ s"DTStatsAggregator.merge requires that both aggregators have parent stats vectors " +
+ s"of the same length. This aggregator's parent stats have length $statsSize, " +
+ s"but the other's have length ${other.statsSize}.")
+ var j = 0
+ while (j < statsSize) {
+ parentStats(j) += other.parentStats(j)
+ j += 1
+ }
+
this
}
}
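
The merge path now folds the other aggregator's parentStats in element-wise, mirroring the allStats loop above it, so parent totals remain correct when partial aggregators from different partitions are combined. A toy check of that property (plain arrays, illustrative only):

    // Element-wise addition preserves parentStats == sum over all instances,
    // so subtraction-derived right-child stats stay correct after the reduce.
    object MergeParentStats extends App {
      val partition1 = Array(2.0, 1.0)
      val partition2 = Array(8.0, 5.0)
      val merged = partition1.zip(partition2).map { case (a, b) => a + b }
      println(merged.mkString(", ")) // 10.0, 6.0
    }
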
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d291ca..4f27dc44ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
/**
* Number of splits for the given feature.
- * For unordered features, there are 2 bins per split.
+ * For unordered features, there is 1 bin per split.
* For ordered features, there is 1 more bin than split.
*/
def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
- numBins(featureIndex) >> 1
+ numBins(featureIndex)
} else {
numBins(featureIndex) - 1
}
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
* there are math.pow(2, arity - 1) - 1 such splits.
- * Each split has 2 corresponding bins.
+ * Each split has 1 corresponding bin.
*/
- def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+ def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
}
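
With this change the bin budget for an unordered feature of arity M drops from 2 * (2^(M-1) - 1), a left and a right bin per split, to 2^(M-1) - 1, one bin per split, plus one parent stats vector per node. A worked check (note the Scala parse: '-' binds tighter than '<<', so 1 << arity - 1 means 1 << (arity - 1)):

    object BinBudget extends App {
      def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
      assert(numUnorderedBins(3) == 3) // was 2 * 3 = 6 bins before this patch
      assert(numUnorderedBins(4) == 7) // was 2 * 7 = 14 bins before this patch
    }
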
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b054a..13aff11007 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a..39c7f9c3be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2bba4..65f0163ec6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a..92d74a1b83 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
}
-
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527..89b64fce96 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(bins.length === 2)
assert(splits(0).length === 3)
assert(bins(0).length === 0)
+ assert(metadata.numSplits(0) === 3)
+ assert(metadata.numBins(0) === 3)
+ assert(metadata.numSplits(1) === 3)
+ assert(metadata.numBins(1) === 3)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
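
The new assertions pin the 1-bin-per-split contract for the arity-3 unordered features this test constructs: 2^(3-1) - 1 = 3 splits, and numBins must now equal numSplits rather than twice it. A quick derivation mirroring the asserted numbers (illustrative only):

    object Arity3Check extends App {
      val arity          = 3
      val expectedSplits = (1 << (arity - 1)) - 1 // 3
      val binsBefore     = 2 * expectedSplits     // 6, under the old 2-bins-per-split scheme
      val binsAfter      = expectedSplits         // 3, what the new asserts check
      println((expectedSplits, binsBefore, binsAfter)) // (3,6,3)
    }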