aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorqiping.lqp <qiping.lqp@alibaba-inc.com>2014-10-03 03:26:17 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-03 03:26:17 -0700
commit2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae (patch)
tree58acf9a7570d29de3d353383c3de1e428a056e9c /mllib
parent1c90347a4bba12df7b76d282a7dbac8e555e049f (diff)
downloadspark-2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae.tar.gz
spark-2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae.tar.bz2
spark-2e4eae3a52e3d04895b00447d1ac56ae3c1b98ae.zip
[SPARK-3366][MLLIB]Compute best splits distributively in decision tree
Currently, all best splits are computed on the driver, which makes the driver a bottleneck for both communication and computation. This PR fix this problem by computed best splits on executors. Instead of send all aggregate stats to the driver node, we can send aggregate stats for a node to a particular executor, using `reduceByKey` operation, then we can compute best split for this node there. Implementation details: Each node now has a nodeStatsAggregator, which save aggregate stats for all features and bins. First use mapPartition to compute node aggregate stats for all nodes in each partition. Then transform node aggregate stats to (nodeIndex, nodeStatsAggregator) pairs and use to `reduceByKey` operation to combine nodeStatsAggregator for the same node. After all stats have been combined, best splits can be computed for each node based on the node aggregate stats. Best split result is collected to driver to construct the decision tree. CC: mengxr manishamde jkbradley, please help me review this, thanks. Author: qiping.lqp <qiping.lqp@alibaba-inc.com> Author: chouqin <liqiping1991@gmail.com> Closes #2595 from chouqin/dt-dist-agg and squashes the following commits: db0d24a [chouqin] fix a minor bug and adjust code a0d9de3 [chouqin] adjust code based on comments 9f201a6 [chouqin] fix bug: statsSize -> allStatsSize a8a7ed0 [chouqin] Merge branch 'master' of https://github.com/apache/spark into dt-dist-agg f13b346 [chouqin] adjust randomforest comments c32636e [chouqin] adjust code based on comments ac6a505 [chouqin] adjust code based on comments 7bbb787 [chouqin] add comments bdd2a63 [qiping.lqp] fix test suite a75df27 [qiping.lqp] fix test suite b5b0bc2 [qiping.lqp] fix style e76414f [qiping.lqp] fix testsuite 748bd45 [qiping.lqp] fix type-mismatch bug 24eacd8 [qiping.lqp] fix type-mismatch bug 5f63d6c [qiping.lqp] add multiclassification using One-Vs-All strategy 4f56496 [qiping.lqp] fix bug f00fc22 [qiping.lqp] fix bug 532993a [qiping.lqp] Compute best splits distributively in decision tree
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala140
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala292
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala1
5 files changed, 182 insertions, 267 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b7dc373ebd..b311d10023 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
-import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.SparkContext._
/**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
// Use all features
agg.metadata.numFeatures
}
- val nodeOffset = agg.getNodeOffset(nodeIndex)
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
- agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
- agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
}
splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
- instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
* For each feature, the sufficient statistics of one bin are updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
- val nodeOffset = agg.getNodeOffset(nodeIndex)
+
// Iterate over features.
if (featuresForNode.nonEmpty) {
// Use subsampled features
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, instanceWeight)
featureIndexIdx += 1
}
} else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
* @return agg
*/
def binSeqOp(
- agg: DTStatsAggregator,
- baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
- mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
@@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging {
agg
}
- // Calculate bin aggregates.
- timer.start("aggregation")
- val binAggregates: DTStatsAggregator = {
- val initAgg = if (metadata.subsamplingFeatures) {
- new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
- } else {
- new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+ /**
+ * Get node index in group --> features indices map,
+ * which is a short cut to find feature indices for a node given node index in group
+ * @param treeToNodeToIndexInfo
+ * @return
+ */
+ def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
+ : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
}
- input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
+ Some(mutableNodeToFeatures.toMap)
}
- timer.stop("aggregation")
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
+ // In each partition, iterate all instances and compute aggregate stats for each node,
+ // yield an (nodeIndex, nodeAggregateStats) pair for each node.
+ // After a `reduceByKey` operation,
+ // stats of a node will be shuffled to a particular partition and be combined together,
+ // then best splits for nodes are found there.
+ // Finally, only best Splits for nodes are collected to driver to construct decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+ val nodeToBestSplits =
+ input.mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterator all instances in current partition and update aggregate stats
+ points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with other partition using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }.reduceByKey((a, b) => a.merge(b))
+ .map { case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(aggStats, splits, featuresForNode)
+ (nodeIndex, (split, stats, predict))
+ }.collectAsMap()
+
+ timer.stop("chooseSplits")
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+ nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
}
}
}
- timer.stop("chooseSplits")
+
}
/**
@@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
- * @param nodeIndex Index into aggregates for node to split in this group.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
- nodeIndex: Int,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
- val metadata: DecisionTreeMetadata = binAggregates.metadata
-
// calculate predict only once
var predict: Option[Predict] = None
// For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ val (bestSplit, bestSplitStats) =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
- val numSplits = metadata.numSplits(featureIndex)
- if (metadata.isContinuous(featureIndex)) {
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
// Find best split.
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (metadata.isUnordered(featureIndex)) {
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
- val numBins = metadata.numBins(featureIndex)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
- val centroidForCategories = if (metadata.isMulticlass) {
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
@@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging {
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
@@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 7fa7725e79..fa7a26f17c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -171,8 +171,8 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
- DecisionTree.findBestSplits(baggedInput,
- metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
timer.stop("findBestSplits")
}
@@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging {
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
* treeToNodeToIndexInfo holds indices selected features for each node:
* treeIndex --> (global) node index --> (node index in group, feature indices).
* The (global) node index is the index in the tree; the node index in group is the
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d49df7a016..55f422dff0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,17 +17,19 @@
package org.apache.spark.mllib.tree.impl
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.impurity._
+
+
/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
* and helps with indexing.
* This class is abstract to support learning with and without feature subsampling.
*/
-private[tree] abstract class DTStatsAggregator(
- val metadata: DecisionTreeMetadata) extends Serializable {
+private[tree] class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator(
/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
- val statsSize: Int = impurityAggregator.statsSize
+ private val statsSize: Int = impurityAggregator.statsSize
+
+ /**
+ * Number of bins for each feature. This is indexed by the feature index.
+ */
+ private val numBins: Array[Int] = {
+ if (featureSubset.isDefined) {
+ featureSubset.get.map(metadata.numBins(_))
+ } else {
+ metadata.numBins
+ }
+ }
+
+ /**
+ * Offset for each feature for calculating indices into the [[allStats]] array.
+ */
+ private val featureOffsets: Array[Int] = {
+ numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+ }
/**
* Indicator for each feature of whether that feature is an unordered feature.
@@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator(
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
/**
- * Total number of elements stored in this aggregator.
+ * Total number of elements stored in this aggregator
*/
- def allStatsSize: Int
+ private val allStatsSize: Int = featureOffsets.last
/**
- * Get flat array of elements stored in this aggregator.
+ * Flat array of elements.
+ * Index for start of stats for a (feature, bin) is:
+ * index = featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features,
+ * the left child stats have binIndex in [0, numBins(featureIndex) / 2))
+ * and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
*/
- protected def allStats: Array[Double]
+ private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+ * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
+ * from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
* (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * [[getLeftRightFeatureOffsets]].
*/
- def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
- impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
- def update(
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+ val i = featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
- * Pre-compute node offset for use with [[nodeUpdate]].
- */
- def getNodeOffset(nodeIndex: Int): Int
-
- /**
* Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ * Update the stats for a given (feature, bin), using the given label.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
*/
- def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
+ def featureUpdate(
+ featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit
+ instanceWeight: Double): Unit = {
+ impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+ label, instanceWeight)
+ }
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
- def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
+ def getFeatureOffset(featureIndex: Int): Int = {
+ require(!isUnordered(featureIndex),
+ s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
+ s" for unordered feature $featureIndex.")
+ featureOffsets(featureIndex)
+ }
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
- def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
- s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+ s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
- val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
- (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
- }
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin), using the given label.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
- */
- def nodeFeatureUpdate(
- nodeFeatureOffset: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
- instanceWeight)
+ val baseOffset = featureOffsets(featureIndex)
+ (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
/**
- * For a given (node, feature), merge the stats for two bins.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+ * For a given feature, merge the stats for two bins.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
- def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
- impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
- nodeFeatureOffset + otherBinIndex * statsSize)
+ def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+ featureOffset + otherBinIndex * statsSize)
}
/**
@@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator(
def merge(other: DTStatsAggregator): DTStatsAggregator = {
require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
- + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
@@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator(
this
}
}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when not subsampling features.
- *
- * @param numNodes Number of nodes to collect statistics for.
- */
-private[tree] class DTStatsAggregatorFixedFeatures(
- metadata: DecisionTreeMetadata,
- numNodes: Int) extends DTStatsAggregator(metadata) {
-
- /**
- * Offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: featureIndex --> offset
- */
- private val featureOffsets: Array[Int] = {
- metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
-
- /**
- * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
- */
- private val nodeStride: Int = featureOffsets.last
-
- override val allStatsSize: Int = numNodes * nodeStride
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
-
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeIndex * nodeStride + featureOffsets(featureIndex)
- }
-}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when subsampling features.
- *
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- * where nodeIndexInfo stores the index in the group and the
- * feature subsets (if using feature subsets).
- */
-private[tree] class DTStatsAggregatorSubsampledFeatures(
- metadata: DecisionTreeMetadata,
- treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: nodeIndex --> featureIndex --> offset
- */
- private val featureOffsets: Array[Array[Int]] = {
- val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
- val offsets = new Array[Array[Int]](numNodes)
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
- nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
- offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
- .scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
- }
- offsets
- }
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- */
- protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
-
- override val allStatsSize: Int = nodeOffsets.last
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- /**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
- * For ordered features only.
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
- }
-}
-
-private[tree] object DTStatsAggregator extends Serializable {
-
- /**
- * Combines two aggregates (modifying the first) and returns the combination.
- */
- def binCombOp(
- agg1: DTStatsAggregator,
- agg2: DTStatsAggregator): DTStatsAggregator = {
- agg1.merge(agg2)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f3e2619bd8..a89e71e115 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,6 +38,17 @@ class InformationGainStats(
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
+
+ override def equals(o: Any) =
+ o match {
+ case other: InformationGainStats => {
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity
+ }
+ case _ => false
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 30669fcd1c..20d372dc1d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+
if (numFeaturesPerNode == numFeatures) {
// featureSubset values should all be None
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),