Diffstat (limited to 'mllib')
 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala               | 140
 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala               |   5
 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala     | 292
 mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala |  11
 mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala          |   1
 5 files changed, 182 insertions(+), 267 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b7dc373ebd..b311d10023 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -23,7 +23,6 @@ import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
-import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -36,6 +35,7 @@ import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.SparkContext._
/**
@@ -328,9 +328,8 @@ object DecisionTree extends Serializable with Logging {
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
@@ -338,7 +337,6 @@ object DecisionTree extends Serializable with Logging {
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
@@ -350,7 +348,6 @@ object DecisionTree extends Serializable with Logging {
// Use all features
agg.metadata.numFeatures
}
- val nodeOffset = agg.getNodeOffset(nodeIndex)
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
@@ -363,16 +360,16 @@ object DecisionTree extends Serializable with Logging {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ agg.getLeftRightFeatureOffsets(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
- agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
- agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+ agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
}
splitIndex += 1
@@ -380,8 +377,7 @@ object DecisionTree extends Serializable with Logging {
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
- instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
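
For unordered categorical features, a single instance updates one bin per candidate split: the instance's category decides, for each split, whether the left-child or the right-child bin receives the weight. A minimal self-contained sketch of that update, with a hypothetical CandidateSplit standing in for Spark's Split/Bin structures:

    // Toy model: stats(side)(splitIndex), side 0 = left child, side 1 = right child.
    final case class CandidateSplit(leftCategories: Set[Double])

    def updateUnordered(
        splits: Array[CandidateSplit],
        stats: Array[Array[Double]],
        featureValue: Double,
        instanceWeight: Double): Unit = {
      var splitIndex = 0
      while (splitIndex < splits.length) {
        // The category's side is decided independently for each candidate split.
        val side = if (splits(splitIndex).leftCategories.contains(featureValue)) 0 else 1
        stats(side)(splitIndex) += instanceWeight
        splitIndex += 1
      }
    }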
@@ -393,26 +389,24 @@ object DecisionTree extends Serializable with Logging {
* For each feature, the sufficient statistics of one bin are updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (node, feature, bin).
+ * each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
- val nodeOffset = agg.getNodeOffset(nodeIndex)
+
// Iterate over features.
if (featuresForNode.nonEmpty) {
// Use subsampled features
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, instanceWeight)
featureIndexIdx += 1
}
} else {
@@ -421,7 +415,7 @@ object DecisionTree extends Serializable with Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, instanceWeight)
featureIndex += 1
}
}
@@ -496,8 +490,8 @@ object DecisionTree extends Serializable with Logging {
* @return agg
*/
def binSeqOp(
- agg: DTStatsAggregator,
- baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
@@ -508,9 +502,9 @@ object DecisionTree extends Serializable with Logging {
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
- mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
instanceWeight, featuresForNode)
}
}
@@ -518,30 +512,76 @@ object DecisionTree extends Serializable with Logging {
agg
}
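
After this change the seqOp threads an array with one aggregator per node in the group, rather than one monolithic aggregator indexed by node. A reduced sketch of that shape, with a toy Agg in place of DTStatsAggregator and a (nodeIndexInGroup, binIndex, weight) tuple in place of BaggedPoint:

    // Toy stand-in for DTStatsAggregator: a flat array of weighted bin counts.
    final class Agg(numBins: Int) extends Serializable {
      val counts = new Array[Double](numBins)
      def update(bin: Int, w: Double): Unit = counts(bin) += w
      def merge(other: Agg): Agg = {  // element-wise add, like DTStatsAggregator.merge
        var i = 0
        while (i < counts.length) { counts(i) += other.counts(i); i += 1 }
        this
      }
    }

    // Each point updates only the aggregator of the node it currently falls into.
    def toyBinSeqOp(agg: Array[Agg], point: (Int, Int, Double)): Array[Agg] = {
      val (nodeIndexInGroup, binIndex, weight) = point
      agg(nodeIndexInGroup).update(binIndex, weight)
      agg
    }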
- // Calculate bin aggregates.
- timer.start("aggregation")
- val binAggregates: DTStatsAggregator = {
- val initAgg = if (metadata.subsamplingFeatures) {
- new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
- } else {
- new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+ /**
+   * Get a map from node index in group to the feature indices for that node,
+   * a shortcut for looking up a node's feature subset given its index in the group.
+   * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo.
+   * @return Mapping: node index in group --> feature indices,
+   *         or None if features are not being subsampled.
+ */
+ def getNodeToFeatures(treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]])
+ : Option[Map[Int, Array[Int]]] = if (!metadata.subsamplingFeatures) {
+ None
+ } else {
+ val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
+ treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
+ nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
+ assert(nodeIndexInfo.featureSubset.isDefined)
+ mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
+ }
}
- input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
+ Some(mutableNodeToFeatures.toMap)
}
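
A worked example of the flattening getNodeToFeatures performs, using a stand-in case class for RandomForest.NodeIndexInfo and made-up indices:

    case class NodeIndexInfo(nodeIndexInGroup: Int, featureSubset: Option[Array[Int]])

    val treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]] = Map(
      0 -> Map(3 -> NodeIndexInfo(0, Some(Array(1, 4)))),  // tree 0, global node 3
      1 -> Map(2 -> NodeIndexInfo(1, Some(Array(0, 2)))))  // tree 1, global node 2

    val nodeToFeatures: Map[Int, Array[Int]] =
      treeToNodeToIndexInfo.values.flatten.map { case (_, info) =>
        info.nodeIndexInGroup -> info.featureSubset.get
      }.toMap
    // nodeToFeatures(0) = Array(1, 4); nodeToFeatures(1) = Array(0, 2)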
- timer.stop("aggregation")
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
+    // In each partition, iterate over all instances and compute aggregate stats for each node,
+    // yielding a (nodeIndex, nodeAggregateStats) pair for each node.
+    // After a `reduceByKey` operation, all stats for a node are shuffled to one partition
+    // and combined there, and the best split for that node is computed on that partition.
+    // Finally, only the best splits are collected to the driver to construct the decision tree.
+ val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
+ val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
+ val nodeToBestSplits =
+ input.mapPartitions { points =>
+      // Construct a nodeStatsAggregators array to hold node aggregate stats;
+      // each node in the group gets its own DTStatsAggregator.
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+      // Iterate over all instances in the current partition and update the aggregate stats.
+ points.foreach(binSeqOp(nodeStatsAggregators, _))
+
+      // Transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
+      // which can be combined with those from other partitions using `reduceByKey`.
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }.reduceByKey((a, b) => a.merge(b))
+ .map { case (nodeIndex, aggStats) =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+
+ // find best split for each node
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(aggStats, splits, featuresForNode)
+ (nodeIndex, (split, stats, predict))
+ }.collectAsMap()
+
+ timer.stop("chooseSplits")
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+ nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
@@ -565,7 +605,7 @@ object DecisionTree extends Serializable with Logging {
}
}
}
- timer.stop("chooseSplits")
+
}
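
A reduced end-to-end sketch of the pipeline this hunk introduces (mapPartitions, then reduceByKey, then a per-node map, then collectAsMap), reusing the toy Agg from the earlier sketch; the heaviest bin stands in for the real binsToBestSplit result:

    import org.apache.spark.SparkContext._  // pair-RDD implicits for reduceByKey (Spark 1.x)
    import org.apache.spark.rdd.RDD

    // points: (nodeIndexInGroup, binIndex) pairs; instance weights fixed at 1.0 for brevity.
    def chooseSplitsSketch(
        points: RDD[(Int, Int)],
        numNodes: Int,
        numBins: Int): scala.collection.Map[Int, Int] = {
      points.mapPartitions { iter =>
        // One partial aggregator per node, local to this partition.
        val aggs = Array.fill(numNodes)(new Agg(numBins))
        iter.foreach { case (node, bin) => aggs(node).update(bin, 1.0) }
        aggs.view.zipWithIndex.map(_.swap).iterator            // (nodeIndex, partial stats)
      }.reduceByKey((a, b) => a.merge(b))                      // combine partials per node
        .mapValues(agg => agg.counts.indexOf(agg.counts.max))  // toy "best split" per node
        .collectAsMap()                                        // only small results reach the driver
    }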
/**
@@ -633,36 +673,33 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
- * @param nodeIndex Index into aggregates for node to split in this group.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
- nodeIndex: Int,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
- val metadata: DecisionTreeMetadata = binAggregates.metadata
-
// calculate predict only once
var predict: Option[Predict] = None
// For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ val (bestSplit, bestSplitStats) =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
- val numSplits = metadata.numSplits(featureIndex)
- if (metadata.isContinuous(featureIndex)) {
+ val numSplits = binAggregates.metadata.numSplits(featureIndex)
+ if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
+ binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
splitIndex += 1
}
// Find best split.
@@ -672,27 +709,29 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
- } else if (metadata.isUnordered(featureIndex)) {
+ } else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
+ binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
- val numBins = metadata.numBins(featureIndex)
+ val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+ val numBins = binAggregates.metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
* The bins are ordered based on centroidForCategories, and this ordering determines which
@@ -700,7 +739,7 @@ object DecisionTree extends Serializable with Logging {
*
* centroidForCategories is a list: (category, centroid)
*/
- val centroidForCategories = if (metadata.isMulticlass) {
+ val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
// For categorical variables in multiclass classification,
// the bins are ordered by the impurity of their corresponding labels.
Range(0, numBins).map { case featureValue =>
@@ -741,7 +780,7 @@ object DecisionTree extends Serializable with Logging {
while (splitIndex < numSplits) {
val currentCategory = categoriesSortedByCentroid(splitIndex)._1
val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
- binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory)
+ binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
splitIndex += 1
}
// lastCategory = index of bin with total aggregates for this (node, feature)
@@ -756,7 +795,8 @@ object DecisionTree extends Serializable with Logging {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats,
+ rightChildStats, binAggregates.metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
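
The continuous-feature branch above hinges on a prefix sum over bins: after the mergeForFeature loop, bin i holds the aggregates of bins 0..i, so each split is evaluated in O(1) (left child = prefix, right child = total minus prefix). A toy version over plain counts, where a balance heuristic stands in for the impurity-based gain:

    // Assumes binCounts.length >= 2; returns (bestSplitIndex, toyGain).
    def bestSplitSketch(binCounts: Array[Double]): (Int, Double) = {
      val prefix = binCounts.clone()
      var i = 0
      while (i < prefix.length - 1) { prefix(i + 1) += prefix(i); i += 1 }  // cumulative sum
      val total = prefix.last
      (0 until prefix.length - 1).map { splitIndex =>
        val left = prefix(splitIndex)
        val right = total - left
        (splitIndex, math.min(left, right))  // stand-in for calculateGainForSplit
      }.maxBy(_._2)
    }
    // bestSplitSketch(Array(2.0, 5.0, 1.0, 4.0)) == (1, 5.0), since prefix = [2, 7, 8, 12]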
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 7fa7725e79..fa7a26f17c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -171,8 +171,8 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
- DecisionTree.findBestSplits(baggedInput,
- metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
timer.stop("findBestSplits")
}
@@ -382,6 +382,7 @@ object RandomForest extends Serializable with Logging {
* @param maxMemoryUsage Bound on size of aggregate statistics.
* @return (nodesForGroup, treeToNodeToIndexInfo).
* nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ *
* treeToNodeToIndexInfo holds indices selected features for each node:
* treeIndex --> (global) node index --> (node index in group, feature indices).
* The (global) node index is the index in the tree; the node index in group is the
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index d49df7a016..55f422dff0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,17 +17,19 @@
package org.apache.spark.mllib.tree.impl
-import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.impurity._
+
+
/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * DecisionTree statistics aggregator for a node.
+ * This holds a flat array of statistics for a set of (features, bins)
* and helps with indexing.
* This class is abstract to support learning with and without feature subsampling.
*/
-private[tree] abstract class DTStatsAggregator(
- val metadata: DecisionTreeMetadata) extends Serializable {
+private[tree] class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
@@ -42,7 +44,25 @@ private[tree] abstract class DTStatsAggregator(
/**
* Number of elements (Double values) used for the sufficient statistics of each bin.
*/
- val statsSize: Int = impurityAggregator.statsSize
+ private val statsSize: Int = impurityAggregator.statsSize
+
+ /**
+ * Number of bins for each feature. This is indexed by the feature index.
+ */
+ private val numBins: Array[Int] = {
+ if (featureSubset.isDefined) {
+ featureSubset.get.map(metadata.numBins(_))
+ } else {
+ metadata.numBins
+ }
+ }
+
+ /**
+ * Offset for each feature for calculating indices into the [[allStats]] array.
+ */
+ private val featureOffsets: Array[Int] = {
+ numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+ }
/**
* Indicator for each feature of whether that feature is an unordered feature.
@@ -51,107 +71,95 @@ private[tree] abstract class DTStatsAggregator(
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
/**
- * Total number of elements stored in this aggregator.
+   * Total number of elements stored in this aggregator.
*/
- def allStatsSize: Int
+ private val allStatsSize: Int = featureOffsets.last
/**
- * Get flat array of elements stored in this aggregator.
+ * Flat array of elements.
+ * Index for start of stats for a (feature, bin) is:
+ * index = featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features,
+   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2)
+   *       and the right child stats in [numBins(featureIndex) / 2, numBins(featureIndex))
*/
- protected def allStats: Array[Double]
+ private val allStats: Array[Double] = new Array[Double](allStatsSize)
+
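
A worked example of this layout with assumed toy sizes (not from the patch): three features with numBins = [4, 2, 3] and statsSize = 2 Doubles per bin:

    val statsSize = 2
    val numBins = Array(4, 2, 3)
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    // featureOffsets = Array(0, 8, 12, 18); allStats needs featureOffsets.last = 18 slots.
    // Stats for (featureIndex = 2, binIndex = 1) start at 12 + 1 * 2 = 14.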
/**
   * Get an [[ImpurityCalculator]] for a given (feature, bin).
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+   * @param featureOffset For ordered features, this is a pre-computed feature offset
+   *                      from [[getFeatureOffset]].
   *                      For unordered features, this is a pre-computed
   *                      (feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * [[getLeftRightFeatureOffsets]].
*/
- def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = {
- impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize)
+ def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
+ impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
/**
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
- def update(
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+ def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+ val i = featureOffsets(featureIndex) + binIndex * statsSize
impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
- * Pre-compute node offset for use with [[nodeUpdate]].
- */
- def getNodeOffset(nodeIndex: Int): Int
-
- /**
* Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ * Update the stats for a given (feature, bin), using the given label.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
+ * For unordered features, this is a pre-computed
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
*/
- def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
+ def featureUpdate(
+ featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit
+ instanceWeight: Double): Unit = {
+ impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
+ label, instanceWeight)
+ }
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
- def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
+ def getFeatureOffset(featureIndex: Int): Int = {
+ require(!isUnordered(featureIndex),
+ s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
+ s" for unordered feature $featureIndex.")
+ featureOffsets(featureIndex)
+ }
/**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
- def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
+ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
- s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
+ s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
- val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
- (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
- }
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin), using the given label.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
- * For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
- */
- def nodeFeatureUpdate(
- nodeFeatureOffset: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
- instanceWeight)
+ val baseOffset = featureOffsets(featureIndex)
+ (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
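
Continuing the toy layout from the earlier example: if feature 0 (4 bins, statsSize = 2) were unordered, its bins split evenly between the two children, so:

    val baseOffset = featureOffsets(0)                            // 0: start of left-child stats
    val rightOffset = baseOffset + (numBins(0) >> 1) * statsSize  // 0 + 2 * 2 = 4: right-child stats
    // Left-child stats for split s start at baseOffset + s * statsSize,
    // right-child stats at rightOffset + s * statsSize.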
/**
- * For a given (node, feature), merge the stats for two bins.
- * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset
- * from [[getNodeFeatureOffset]].
+ * For a given feature, merge the stats for two bins.
+ * @param featureOffset For ordered features, this is a pre-computed feature offset
+ * from [[getFeatureOffset]].
* For unordered features, this is a pre-computed
- * (node, feature, left/right child) offset from
- * [[getLeftRightNodeFeatureOffsets]].
+ * (feature, left/right child) offset from
+ * [[getLeftRightFeatureOffsets]].
* @param binIndex The other bin is merged into this bin.
* @param otherBinIndex This bin is not modified.
*/
- def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
- impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize,
- nodeFeatureOffset + otherBinIndex * statsSize)
+ def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = {
+ impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize,
+ featureOffset + otherBinIndex * statsSize)
}
/**
@@ -161,7 +169,7 @@ private[tree] abstract class DTStatsAggregator(
def merge(other: DTStatsAggregator): DTStatsAggregator = {
require(allStatsSize == other.allStatsSize,
s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors."
- + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
+ + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.")
var i = 0
// TODO: Test BLAS.axpy
while (i < allStatsSize) {
@@ -171,149 +179,3 @@ private[tree] abstract class DTStatsAggregator(
this
}
}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when not subsampling features.
- *
- * @param numNodes Number of nodes to collect statistics for.
- */
-private[tree] class DTStatsAggregatorFixedFeatures(
- metadata: DecisionTreeMetadata,
- numNodes: Int) extends DTStatsAggregator(metadata) {
-
- /**
- * Offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: featureIndex --> offset
- */
- private val featureOffsets: Array[Int] = {
- metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
-
- /**
- * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
- */
- private val nodeStride: Int = featureOffsets.last
-
- override val allStatsSize: Int = numNodes * nodeStride
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
-
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeIndex * nodeStride + featureOffsets(featureIndex)
- }
-}
-
-/**
- * DecisionTree statistics aggregator.
- * This holds a flat array of statistics for a set of (nodes, features, bins)
- * and helps with indexing.
- *
- * This instance of [[DTStatsAggregator]] is used when subsampling features.
- *
- * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
- * where nodeIndexInfo stores the index in the group and the
- * feature subsets (if using feature subsets).
- */
-private[tree] class DTStatsAggregatorSubsampledFeatures(
- metadata: DecisionTreeMetadata,
- treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- * Mapping: nodeIndex --> featureIndex --> offset
- */
- private val featureOffsets: Array[Array[Int]] = {
- val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
- val offsets = new Array[Array[Int]](numNodes)
- treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
- nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
- offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
- .scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
- }
- offsets
- }
-
- /**
- * For each node, offset for each feature for calculating indices into the [[allStats]] array.
- */
- protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
-
- override val allStatsSize: Int = nodeOffsets.last
-
- /**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats precede the right child stats
- * in the binIndex order.
- */
- override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
-
- override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
-
- /**
- * Faster version of [[update]].
- * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
- * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def nodeUpdate(
- nodeOffset: Int,
- nodeIndex: Int,
- featureIndex: Int,
- binIndex: Int,
- label: Double,
- instanceWeight: Double): Unit = {
- val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
- }
-
- /**
- * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
- * For ordered features only.
- * @param featureIndex Index of feature in featuresForNodes(nodeIndex).
- * Note: This is NOT the original feature index.
- */
- override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
- }
-}
-
-private[tree] object DTStatsAggregator extends Serializable {
-
- /**
- * Combines two aggregates (modifying the first) and returns the combination.
- */
- def binCombOp(
- agg1: DTStatsAggregator,
- agg2: DTStatsAggregator): DTStatsAggregator = {
- agg1.merge(agg2)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index f3e2619bd8..a89e71e115 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -38,6 +38,17 @@ class InformationGainStats(
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
.format(gain, impurity, leftImpurity, rightImpurity)
}
+
+ override def equals(o: Any) =
+ o match {
+ case other: InformationGainStats => {
+ gain == other.gain &&
+ impurity == other.impurity &&
+ leftImpurity == other.leftImpurity &&
+ rightImpurity == other.rightImpurity
+ }
+ case _ => false
+ }
}
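
One caveat: this patch overrides equals without a matching hashCode, so equal InformationGainStats instances may land in different hash buckets. A sketch of a consistent companion override (not part of this commit):

    override def hashCode: Int =
      java.util.Arrays.hashCode(Array(gain, impurity, leftImpurity, rightImpurity))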
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 30669fcd1c..20d372dc1d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -145,6 +145,7 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(nodesForGroup.size === numTrees, failString)
assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+
if (numFeaturesPerNode == numFeatures) {
// featureSubset values should all be None
assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),