path: root/mllib/src
author    Joseph K. Bradley <joseph.kurata.bradley@gmail.com>    2014-09-28 21:44:50 -0700
committer Xiangrui Meng <meng@databricks.com>    2014-09-28 21:44:50 -0700
commit    0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe (patch)
tree      f3d82bc455227282e96e471b45a87fb07923edce /mllib/src
parent    f350cd307045c2c02e713225d8f1247f18ba123e (diff)
[SPARK-1545] [mllib] Add Random Forests
This PR adds RandomForest to MLlib. The implementation is basic, and future performance optimizations will be important. (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level

This implementation is based on an implementation by manishamde and the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations). Thank you for your inputs!

## Testing

Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets.

Performance: This has been performance tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs). Results below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from feature subsampling.

In the table below, each row is a single (random) dataset. The two sets of result columns are for two different RF implementations:
* (numTrees): from an earlier commit, after implementing RandomForest to train multiple trees at once. It does not include any code for feature subsampling.
* (feature subsets): from this PR's current code, after implementing feature subsampling.

These tests were meant to identify regressions in DecisionTree, so they train 1 tree with all of the features (i.e., no feature subsampling). They were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels). Speedup values < 1 indicate slowdowns relative to the old DecisionTree implementation.

numInstances | numFeatures | runtime (sec), numTrees | speedup, numTrees | runtime (sec), feature subsets | speedup, feature subsets
---- | ---- | ---- | ---- | ---- | ----
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees=10 and with sqrt(numFeatures), and those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees.
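For reference, this is roughly how the new API is invoked from user code (parameter names match the trainClassifier method added in this PR). The dataset path and parameter values are illustrative only, not the benchmark settings above, and an existing SparkContext `sc` (e.g., from spark-shell) is assumed:

```scala
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.util.MLUtils

// Illustrative data: any RDD[LabeledPoint] works; this path is a placeholder.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")

// Train 10 trees; "auto" picks "sqrt" feature subsets for multi-tree classification.
val model = RandomForest.trainClassifier(
  input = data,
  numClassesForClassification = 2,
  categoricalFeaturesInfo = Map.empty[Int, Int],
  numTrees = 10,
  featureSubsetStrategy = "auto",
  impurity = "gini",
  maxDepth = 4,
  maxBins = 100,
  seed = 12345)

// Predict the label of a single example.
val prediction = model.predict(data.first().features)
```

trainRegressor follows the same shape, with impurity = "variance" and without the numClassesForClassification parameter.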
# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed. (It split levels into groups, but we now use a queue of nodes.)
* Many small changes to support RFs. (Note: These methods should be moved to RandomForest.scala in a later PR, but they are kept in DecisionTree.scala for now to make code comparison easier.)

## RandomForest
* Main train() method is from the old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly in order to track memory usage.

## RandomForestModel
* Stores an Array[DecisionTreeModel].
* Prediction: for classification, the most common label; for regression, the mean. (A minimal sketch of this rule follows the design-decisions list below.) We could support other methods later.

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.

## DTStatsAggregator
* There are 2 types of functionality (with and without subsampling features), and these require different indexing methods. (We could treat both as subsampling, but that would be less efficient.)
* DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality.

## impurities
* These now take instance weights.

## Node
* Some vals changed to vars.
* This is unfortunately a public API change (DeveloperApi). It could be avoided by creating a LearningNode struct, but that would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on.
* RandomForest public API: What options should be easily supported by the train* methods? Should ALL options be in the Java-friendly constructors? Should there be a constructor taking Strategy?
* Feature subsampling options: What options should be supported? scikit-learn supports the same options, except for "onethird." One option would be to allow users to specify fractions ("0.1"): the current options could still be supported, and any unrecognized values would be parsed as Doubles in [0, 1].
* Splits and bins are computed before bootstrapping, so all trees use the same discretization.
* One queue is used, instead of one queue per tree.
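As a concrete illustration of the prediction rule described above (majority vote for classification, mean for regression), here is a minimal sketch over an Array[DecisionTreeModel]. It is only an illustration of the rule, not the RandomForestModel code from this PR:

```scala
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.model.DecisionTreeModel

// Sketch of the combination rule only; efficiency and edge cases are ignored.
def predictForest(
    trees: Array[DecisionTreeModel],
    features: Vector,
    isClassification: Boolean): Double = {
  val perTreePredictions = trees.map(_.predict(features))
  if (isClassification) {
    // Classification: the most common predicted label wins (majority vote).
    perTreePredictions.groupBy(identity).maxBy(_._2.length)._1
  } else {
    // Regression: the mean of the per-tree predictions.
    perTreePredictions.sum / perTreePredictions.length
  }
}
```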
CC: mengxr manishamde codedeft chouqin
Please let me know if you have suggestions---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: chouqin <liqiping1991@gmail.com>

Closes #2435 from jkbradley/rfs-new and squashes the following commits:

c694174 [Joseph K. Bradley] Fixed typo
cc59d78 [Joseph K. Bradley] fixed imports
e25909f [Joseph K. Bradley] Simplified node group maps. Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits
fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression. Updated docs with references.
ef7c293 [Joseph K. Bradley] Updates based on code review. Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite
593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features). Now it checks >= 1.
a1a08df [Joseph K. Bradley] Removed old comments
866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds.
ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option
bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues.
d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest). Fixed bugs so that RandomForest now runs.
746d43c [Joseph K. Bradley] Implemented feature subsampling. Tested DecisionTree but not RandomForest.
6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new. Added RandomForestModel.toString
b7ae594 [Joseph K. Bradley] Updated docs. Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes.
121c74e [Joseph K. Bradley] Basic random forests are implemented. Random features per node not yet implemented. Test suite not implemented.
325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new
4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet
eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
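For intuition, the featureSubsetStrategy values (including the sqrt/onethird defaults introduced in commit fbe9a1e above) translate into a per-node feature count roughly as sketched below. The exact rounding and the >= 1 guard (see commit 593b13c) live in DecisionTreeMetadata, so treat this as an approximation rather than the PR's code:

```scala
// Sketch: approximate mapping from a strategy name to the number of features tried per node.
// The rounding choices here are illustrative assumptions, not the DecisionTreeMetadata code.
def approxNumFeaturesPerNode(
    strategy: String,
    numFeatures: Int,
    numTrees: Int,
    isClassification: Boolean): Int = strategy match {
  case "all" => numFeatures
  case "sqrt" => math.sqrt(numFeatures).ceil.toInt
  case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) // keep >= 1
  case "onethird" => math.ceil(numFeatures / 3.0).toInt
  case "auto" =>
    if (numTrees == 1) numFeatures                                  // single tree: all features
    else if (isClassification) math.sqrt(numFeatures).ceil.toInt    // forest + classification
    else math.ceil(numFeatures / 3.0).toInt                         // forest + regression
  case other => throw new IllegalArgumentException(s"Unknown featureSubsetStrategy: $other")
}
```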
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala  457
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala  451
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala  80
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala  219
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala  47
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala  13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala  105
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  210
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala  245
13 files changed, 1353 insertions, 492 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c7f2576c82..b7dc373ebd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -18,12 +18,14 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
@@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
@@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* @return DecisionTreeModel that can be used for prediction
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
-
- val timer = new TimeTracker()
-
- timer.start("total")
-
- timer.start("init")
-
- val retaggedInput = input.retag(classOf[LabeledPoint])
- val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy)
- logDebug("algo = " + strategy.algo)
- logDebug("maxBins = " + metadata.maxBins)
-
- // Find the splits and the corresponding bins (interval between the splits) using a sample
- // of the input data.
- timer.start("findSplitsBins")
- val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
- timer.stop("findSplitsBins")
- logDebug("numBins: feature: number of bins")
- logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
- s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
- }.mkString("\n"))
-
- // Bin feature values (TreePoint representation).
- // Cache input RDD for speedup during multiple passes.
- val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- .persist(StorageLevel.MEMORY_AND_DISK)
-
- // depth of the decision tree
- val maxDepth = strategy.maxDepth
- require(maxDepth <= 30,
- s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
-
- // Calculate level for single group construction
-
- // Max memory usage for aggregates
- val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
- logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
- // TODO: Calculate memory usage more precisely.
- val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
-
- logDebug("numElementsPerNode = " + numElementsPerNode)
- val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
- val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
- logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup)
- // nodes at a level is 2^level. level is zero indexed.
- val maxLevelForSingleGroup = math.max(
- (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
- logDebug("max level for single group = " + maxLevelForSingleGroup)
-
- timer.stop("init")
-
- /*
- * The main idea here is to perform level-wise training of the decision tree nodes thus
- * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
- * Each data sample is handled by a particular node at that level (or it reaches a leaf
- * beforehand and is not used in later levels.
- */
-
- var topNode: Node = null // set on first iteration
- var level = 0
- var break = false
- while (level <= maxDepth && !break) {
- logDebug("#####################################")
- logDebug("level = " + level)
- logDebug("#####################################")
-
- // Find best split for all nodes at a level.
- timer.start("findBestSplits")
- val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
- metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
- timer.stop("findBestSplits")
-
- if (level == 0) {
- topNode = tmpTopNode
- }
- if (doneTraining) {
- break = true
- logDebug("done training")
- }
-
- level += 1
- }
-
- logDebug("#####################################")
- logDebug("Extracting tree model")
- logDebug("#####################################")
-
- timer.stop("total")
-
- logInfo("Internal timing for DecisionTree:")
- logInfo(s"$timer")
-
- new DecisionTreeModel(topNode, strategy.algo)
+ // Note: random seed will not be used since numTrees = 1.
+ val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
+ val rfModel = rf.train(input)
+ rfModel.trees(0)
}
}
@@ -353,57 +265,9 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Returns an array of optimal splits for all nodes at a given level. Splits the task into
- * multiple groups if the level-wise training task could lead to memory overflow.
- *
- * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
- * @param metadata Learning and dataset metadata
- * @param level Level of the tree
- * @param topNode Root node of the tree (or invalid node when training first level).
- * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
- * @param bins possible bins for all features, indexed (numFeatures)(numBins)
- * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
- * @return (root, doneTraining) where:
- * root = Root node (which is newly created on the first iteration),
- * doneTraining = true if no more internal nodes were created.
- */
- private[tree] def findBestSplits(
- input: RDD[TreePoint],
- metadata: DecisionTreeMetadata,
- level: Int,
- topNode: Node,
- splits: Array[Array[Split]],
- bins: Array[Array[Bin]],
- maxLevelForSingleGroup: Int,
- timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
-
- // split into groups to avoid memory overflow during aggregation
- if (level > maxLevelForSingleGroup) {
- // When information for all nodes at a given level cannot be stored in memory,
- // the nodes are divided into multiple groups at each level with the number of groups
- // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
- // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
- val numGroups = 1 << level - maxLevelForSingleGroup
- logDebug("numGroups = " + numGroups)
- // Iterate over each group of nodes at a level.
- var groupIndex = 0
- var doneTraining = true
- while (groupIndex < numGroups) {
- val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
- topNode, splits, bins, timer, numGroups, groupIndex)
- doneTraining = doneTraining && doneTrainingGroup
- groupIndex += 1
- }
- (topNode, doneTraining) // Not first iteration, so topNode was already set.
- } else {
- findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
- }
- }
-
- /**
* Get the node index corresponding to this data point.
- * This function mimics prediction, passing an example from the root node down to a node
- * at the current level being trained; that node's index is returned.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
*
* @param node Node in tree from which to classify the given data point.
* @param binnedFeatures Binned feature vector for data point.
@@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging {
* Otherwise, last node reachable in tree matching this example.
* Note: This is the global node index, i.e., the index used in the tree.
* This index is different from the index used during training a particular
- * set of nodes in a (level, group).
+ * group of nodes on one call to [[findBestSplits()]].
*/
private def predictNodeIndex(
node: Node,
binnedFeatures: Array[Int],
bins: Array[Array[Bin]],
unorderedFeatures: Set[Int]): Int = {
- if (node.isLeaf) {
+ if (node.isLeaf || node.split.isEmpty) {
+ // Node is either leaf, or has not yet been split.
node.id
} else {
val featureIndex = node.split.get.feature
@@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+ * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param unorderedFeatures Set of indices of unordered features.
+ * @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
nodeIndex: Int,
bins: Array[Array[Bin]],
- unorderedFeatures: Set[Int]): Unit = {
- // Iterate over all features.
- val numFeatures = treePoint.binnedFeatures.size
+ unorderedFeatures: Set[Int],
+ instanceWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
+ val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ featuresForNode.get.size
+ } else {
+ // Use all features
+ agg.metadata.numFeatures
+ }
val nodeOffset = agg.getNodeOffset(nodeIndex)
- var featureIndex = 0
- while (featureIndex < numFeatures) {
+ // Iterate over features.
+ var featureIndexIdx = 0
+ while (featureIndexIdx < numFeaturesPerNode) {
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
- agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
// Update the left or right bin for each split.
- val numSplits = agg.numSplits(featureIndex)
+ val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
- agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label)
+ agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
+ instanceWeight)
} else {
- agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label)
+ agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
+ instanceWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label)
+ agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label,
+ instanceWeight)
}
- featureIndex += 1
+ featureIndexIdx += 1
}
}
@@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* @param treePoint Data point being aggregated.
- * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
- * @return agg
+ * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes).
+ * @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- nodeIndex: Int): Unit = {
+ nodeIndex: Int,
+ instanceWeight: Double,
+ featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
val nodeOffset = agg.getNodeOffset(nodeIndex)
- // Iterate over all features.
- val numFeatures = agg.numFeatures
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label)
- featureIndex += 1
+ // Iterate over features.
+ if (featuresForNode.nonEmpty) {
+ // Use subsampled features
+ var featureIndexIdx = 0
+ while (featureIndexIdx < featuresForNode.get.size) {
+ val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
+ agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight)
+ featureIndexIdx += 1
+ }
+ } else {
+ // Use all features
+ val numFeatures = agg.metadata.numFeatures
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val binIndex = treePoint.binnedFeatures(featureIndex)
+ agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight)
+ featureIndex += 1
+ }
}
}
/**
- * Returns an array of optimal splits for a group of nodes at a given level
+ * Given a group of nodes, this finds the best split for each node.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
- * @param level Level of the tree
- * @param topNode Root node of the tree (or invalid node when training first level).
+ * @param topNodes Root node for each tree. Used for matching instances with nodes.
+ * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
- * @param numGroups total number of node groups at the current level. Default value is set to 1.
- * @param groupIndex index of the node group being processed. Default value is set to 0.
- * @return (root, doneTraining) where:
- * root = Root node (which is newly created on the first iteration),
- * doneTraining = true if no more internal nodes were created.
+ * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
+ * Updated with new non-leaf nodes which are created.
*/
- private def findBestSplitsPerGroup(
- input: RDD[TreePoint],
+ private[tree] def findBestSplits(
+ input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
- level: Int,
- topNode: Node,
+ topNodes: Array[Node],
+ nodesForGroup: Map[Int, Array[Node]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
- timer: TimeTracker,
- numGroups: Int = 1,
- groupIndex: Int = 0): (Node, Boolean) = {
+ nodeQueue: mutable.Queue[(Int, Node)],
+ timer: TimeTracker = new TimeTracker): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
*
- * *Level-wise training*
- * We perform bin calculations for all nodes at the given level to avoid making multiple
- * passes over the data. Thus, for a slightly increased computation and storage cost we save
- * several iterations over the data especially at higher levels of the decision tree.
+ * *Group-wise training*
+ * We perform bin calculations for groups of nodes to reduce the number of
+ * passes over the data. Each iteration requires more computation and storage,
+ * but saves several iterations over the data.
*
* *Bin-wise computation*
* We use a bin-wise best split computation strategy instead of a straightforward best split
* computation strategy. Instead of analyzing each sample for contribution to the left/right
* child node impurity of every split, we first categorize each feature of a sample into a
- * bin. Each bin is an interval between a low and high split. Since each split, and thus bin,
- * is ordered (read ordering for categorical variables in the findSplitsBins method),
- * we exploit this structure to calculate aggregates for bins and then use these aggregates
+ * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates
* to calculate information gain for each split.
*
* *Aggregation over partitions*
@@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging {
* drastically reduce the communication overhead.
*/
- // Common calculations for multiple nested methods:
-
- // numNodes: Number of nodes in this (level of tree, group),
- // where nodes at deeper (larger) levels may be divided into groups.
- val numNodes = Node.maxNodesInLevel(level) / numGroups
+ // numNodes: Number of nodes in this group
+ val numNodes = nodesForGroup.values.map(_.size).sum
logDebug("numNodes = " + numNodes)
-
logDebug("numFeatures = " + metadata.numFeatures)
logDebug("numClasses = " + metadata.numClasses)
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
- // shift when more than one group is used at deep tree level
- val groupShift = numNodes * groupIndex
-
- // Used for treePointToNodeIndex to get an index for this (level, group).
- // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level.
- // - groupShift corrects for groups in this level before the current group.
- val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift
-
- /**
- * Find the node index for the given example.
- * Nodes are indexed from 0 at the start of this (level, group).
- * If the example does not reach this level, returns a value < 0.
- */
- def treePointToNodeIndex(treePoint: TreePoint): Int = {
- if (level == 0) {
- 0
- } else {
- val globalNodeIndex =
- predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
- globalNodeIndex - globalNodeIndexOffset
- }
- }
-
/**
* Performs a sequential aggregation over a partition.
*
@@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging {
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
- * @param treePoint Data point being aggregated.
+ * @param baggedPoint Data point being aggregated.
* @return agg
*/
def binSeqOp(
agg: DTStatsAggregator,
- treePoint: TreePoint): DTStatsAggregator = {
- val nodeIndex = treePointToNodeIndex(treePoint)
- // If the example does not reach this level, then nodeIndex < 0.
- // If the example reaches this level but is handled in a different group,
- // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group).
- if (nodeIndex >= 0 && nodeIndex < numNodes) {
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg, treePoint, nodeIndex)
- } else {
- mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures)
+ baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
+ bins, metadata.unorderedFeatures)
+ val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
+ // If the example does not reach a node in this group, then nodeInfo will be null.
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures,
+ instanceWeight, featuresForNode)
+ }
}
}
agg
@@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging {
// Calculate bin aggregates.
timer.start("aggregation")
val binAggregates: DTStatsAggregator = {
- val initAgg = new DTStatsAggregator(metadata, numNodes)
+ val initAgg = if (metadata.subsamplingFeatures) {
+ new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo)
+ } else {
+ new DTStatsAggregatorFixedFeatures(metadata, numNodes)
+ }
input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp)
}
timer.stop("aggregation")
- // Calculate best splits for all nodes at a given level
+ // Calculate best splits for all nodes in the group
timer.start("chooseSplits")
- // On the first iteration, we need to get and return the newly created root node.
- var newTopNode: Node = topNode
-
- // Iterate over all nodes at this level
- var nodeIndex = 0
- var internalNodeCount = 0
- while (nodeIndex < numNodes) {
- val (split: Split, stats: InformationGainStats, predict: Predict) =
- binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
- logDebug("best split = " + split)
-
- val globalNodeIndex = globalNodeIndexOffset + nodeIndex
- // Extract info for this node at the current level.
- val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
- val node =
- new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
- logDebug("Node = " + node)
-
- if (!isLeaf) {
- internalNodeCount += 1
- }
- if (level == 0) {
- newTopNode = node
- } else {
- // Set parent.
- val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
- if (Node.isLeftChild(globalNodeIndex)) {
- parentNode.leftNode = Some(node)
- } else {
- parentNode.rightNode = Some(node)
+ // Iterate over all nodes in this group.
+ nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+ nodesForTree.foreach { node =>
+ val nodeIndex = node.id
+ val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val (split: Split, stats: InformationGainStats, predict: Predict) =
+ binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode)
+ logDebug("best split = " + split)
+
+ // Extract info for this node. Create children if not leaf.
+ val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
+ assert(node.id == nodeIndex)
+ node.predict = predict.predict
+ node.isLeaf = isLeaf
+ node.stats = Some(stats)
+ logDebug("Node = " + node)
+
+ if (!isLeaf) {
+ node.split = Some(split)
+ node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
+ node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
+ nodeQueue.enqueue((treeIndex, node.leftNode.get))
+ nodeQueue.enqueue((treeIndex, node.rightNode.get))
+ logDebug("leftChildIndex = " + node.leftNode.get.id +
+ ", impurity = " + stats.leftImpurity)
+ logDebug("rightChildIndex = " + node.rightNode.get.id +
+ ", impurity = " + stats.rightImpurity)
}
}
- if (level < metadata.maxDepth) {
- logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
- ", impurity = " + stats.leftImpurity)
- logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
- ", impurity = " + stats.rightImpurity)
- }
-
- nodeIndex += 1
}
timer.stop("chooseSplits")
-
- val doneTraining = internalNodeCount == 0
- (newTopNode, doneTraining)
}
/**
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
* @param leftImpurityCalculator left node aggregates for this (feature, split)
* @param rightImpurityCalculator right node aggregate for this (feature, split)
- * @return information gain and statistics for all splits
+ * @return information gain and statistics for split
*/
private def calculateGainForSplit(
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- level: Int,
metadata: DecisionTreeMetadata): InformationGainStats = {
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
@@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging {
* Calculate predict value for current node, given stats of any split.
* Note that this function is called only once for each node.
* @param leftImpurityCalculator left node aggregates for a split
- * @param rightImpurityCalculator right node aggregates for a node
+ * @param rightImpurityCalculator right node aggregates for a split
* @return predict value for current node
*/
private def calculatePredict(
@@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging {
/**
* Find the best split for a node.
* @param binAggregates Bin statistics.
- * @param nodeIndex Index for node to split in this (level, group).
- * @return tuple for best split: (Split, information gain)
+ * @param nodeIndex Index into aggregates for node to split in this group.
+ * @return tuple for best split: (Split, information gain, prediction at node)
*/
private def binsToBestSplit(
binAggregates: DTStatsAggregator,
nodeIndex: Int,
- level: Int,
- metadata: DecisionTreeMetadata,
- splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+
+ val metadata: DecisionTreeMetadata = binAggregates.metadata
// calculate predict only once
var predict: Option[Predict] = None
// For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
+ val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
+ val featureIndex = if (featuresForNode.nonEmpty) {
+ featuresForNode.get.apply(featureIndexIdx)
+ } else {
+ featureIndexIdx
+ }
val numSplits = metadata.numSplits(featureIndex)
if (metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
// Afterwards, binAggregates for a bin is the sum of aggregates for
// that bin + all preceding bins.
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
var splitIndex = 0
while (splitIndex < numSplits) {
binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
@@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging {
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
(splitIdx, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val (leftChildOffset, rightChildOffset) =
- binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex)
+ binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
// Ordered categorical feature
- val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
+ val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx)
val numBins = metadata.numBins(featureIndex)
/* Each bin is one category (feature value).
@@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging {
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
+ val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata)
(splitIndex, gainStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
@@ -904,18 +773,6 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Get the number of values to be stored per node in the bin aggregates.
- */
- private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
- val totalBins = metadata.numBins.map(_.toLong).sum
- if (metadata.isClassification) {
- metadata.numClasses * totalBins
- } else {
- 3 * totalBins
- }
- }
-
- /**
* Returns splits and bins for decision tree calculation.
* Continuous and categorical features are handled differently.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
new file mode 100644
index 0000000000..7fa7725e79
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -0,0 +1,451 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impurity.Impurities
+import org.apache.spark.mllib.tree.model._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+/**
+ * :: Experimental ::
+ * A class which implements a random forest learning algorithm for classification and regression.
+ * It supports both continuous and categorical features.
+ *
+ * The settings for featureSubsetStrategy are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
+ * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
+ * random forests]]
+ *
+ * @param strategy The configuration parameters for the random forest algorithm which specify
+ * the type of algorithm (classification, regression, etc.), feature type
+ * (continuous, categorical), depth of the tree, quantile calculation strategy,
+ * etc.
+ * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * If "auto" is set, this parameter is set based on numTrees:
+ * if numTrees == 1, set to "all";
+ * if numTrees > 1 (forest) set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ */
+@Experimental
+private class RandomForest (
+ private val strategy: Strategy,
+ private val numTrees: Int,
+ featureSubsetStrategy: String,
+ private val seed: Int)
+ extends Serializable with Logging {
+
+ strategy.assertValid()
+ require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
+ require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
+ s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
+ s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
+
+ /**
+ * Method to train a decision tree model over an RDD
+ * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @return RandomForestModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): RandomForestModel = {
+
+ val timer = new TimeTracker()
+
+ timer.start("total")
+
+ timer.start("init")
+
+ val retaggedInput = input.retag(classOf[LabeledPoint])
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+ logDebug("algo = " + strategy.algo)
+ logDebug("numTrees = " + numTrees)
+ logDebug("seed = " + seed)
+ logDebug("maxBins = " + metadata.maxBins)
+ logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
+ logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+
+ // Find the splits and the corresponding bins (interval between the splits) using a sample
+ // of the input data.
+ timer.start("findSplitsBins")
+ val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
+ timer.stop("findSplitsBins")
+ logDebug("numBins: feature: number of bins")
+ logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
+ s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
+ }.mkString("\n"))
+
+ // Bin feature values (TreePoint representation).
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
+ val baggedInput = if (numTrees > 1) {
+ BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
+ } else {
+ BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ // depth of the decision tree
+ val maxDepth = strategy.maxDepth
+ require(maxDepth <= 30,
+ s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
+
+ // Max memory usage for aggregates
+ // TODO: Calculate memory usage more precisely.
+ val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
+ logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
+ val maxMemoryPerNode = {
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
+ Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
+ .take(metadata.numFeaturesPerNode).map(_._2))
+ } else {
+ None
+ }
+ RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ }
+ require(maxMemoryPerNode <= maxMemoryUsage,
+ s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
+ " which is too small for the given features." +
+ s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")
+
+ timer.stop("init")
+
+ /*
+ * The main idea here is to perform group-wise training of the decision tree nodes thus
+ * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
+ * Each data sample is handled by a particular node (or it reaches a leaf and is not used
+ * in lower levels).
+ */
+
+ // FIFO queue of nodes to train: (treeIndex, node)
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+
+ val rng = new scala.util.Random()
+ rng.setSeed(seed)
+
+ // Allocate and queue root nodes.
+ val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
+ Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
+
+ while (nodeQueue.nonEmpty) {
+ // Collect some nodes to split, and choose features for each node (if subsampling).
+ // Each group of nodes may come from one or multiple trees, and at multiple levels.
+ val (nodesForGroup, treeToNodeToIndexInfo) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+ // Sanity check (should never occur):
+ assert(nodesForGroup.size > 0,
+ s"RandomForest selected empty nodesForGroup. Error for unknown reason.")
+
+ // Choose node splits, and enqueue new nodes as needed.
+ timer.start("findBestSplits")
+ DecisionTree.findBestSplits(baggedInput,
+ metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ timer.stop("findBestSplits")
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+ val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
+ RandomForestModel.build(trees)
+ }
+
+}
+
+object RandomForest extends Serializable with Logging {
+
+ /**
+ * Method to train a decision tree model for binary or multiclass classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels should take values {0, 1, ..., numClasses-1}.
+ * @param strategy Parameters for training each tree in the forest.
+ * @param numTrees Number of trees in the random forest.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * If "auto" is set, this parameter is set based on numTrees:
+ * if numTrees == 1, set to "all";
+ * if numTrees > 1 (forest) set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ strategy: Strategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Int): RandomForestModel = {
+ require(strategy.algo == Classification,
+ s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
+ val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
+ rf.train(input)
+ }
+
+ /**
+ * Method to train a decision tree model for binary or multiclass classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels should take values {0, 1, ..., numClasses-1}.
+ * @param numClassesForClassification number of classes for classification.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
+ * E.g., an entry (n -> k) indicates that feature n is categorical
+ * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param numTrees Number of trees in the random forest.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * If "auto" is set, this parameter is set based on numTrees:
+ * if numTrees == 1, set to "all";
+ * if numTrees > 1 (forest) set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * @param impurity Criterion used for information gain calculation.
+ * Supported values: "gini" (recommended) or "entropy".
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (suggested value: 4)
+ * @param maxBins maximum number of bins used for splitting features
+ * (suggested value: 100)
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ val impurityType = Impurities.fromString(impurity)
+ val strategy = new Strategy(Classification, impurityType, maxDepth,
+ numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
+ trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]]
+ */
+ def trainClassifier(
+ input: JavaRDD[LabeledPoint],
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int): RandomForestModel = {
+ trainClassifier(input.rdd, numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+ }
+
+ /**
+ * Method to train a decision tree model for regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels are real numbers.
+ * @param strategy Parameters for training each tree in the forest.
+ * @param numTrees Number of trees in the random forest.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * If "auto" is set, this parameter is set based on numTrees:
+ * if numTrees == 1, set to "all";
+ * if numTrees > 1 (forest) set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ strategy: Strategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ seed: Int): RandomForestModel = {
+ require(strategy.algo == Regression,
+ s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
+ val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
+ rf.train(input)
+ }
+
+ /**
+ * Method to train a decision tree model for regression.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * Labels are real numbers.
+ * @param categoricalFeaturesInfo Map storing arity of categorical features.
+ * E.g., an entry (n -> k) indicates that feature n is categorical
+ * with k categories indexed from 0: {0, 1, ..., k-1}.
+ * @param numTrees Number of trees in the random forest.
+ * @param featureSubsetStrategy Number of features to consider for splits at each node.
+ * Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ * If "auto" is set, this parameter is set based on numTrees:
+ * if numTrees == 1, set to "all";
+ * if numTrees > 1 (forest) set to "sqrt" for classification and
+ * to "onethird" for regression.
+ * @param impurity Criterion used for information gain calculation.
+ * Supported values: "variance".
+ * @param maxDepth Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (suggested value: 4)
+ * @param maxBins maximum number of bins used for splitting features
+ * (suggested value: 100)
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @return RandomForestModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ categoricalFeaturesInfo: Map[Int, Int],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ val impurityType = Impurities.fromString(impurity)
+ val strategy = new Strategy(Regression, impurityType, maxDepth,
+ 0, maxBins, Sort, categoricalFeaturesInfo)
+ trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]]
+ */
+ def trainRegressor(
+ input: JavaRDD[LabeledPoint],
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ impurity: String,
+ maxDepth: Int,
+ maxBins: Int,
+ seed: Int): RandomForestModel = {
+ trainRegressor(input.rdd,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+ }
+
+ /**
+ * List of supported feature subset sampling strategies.
+ */
+ val supportedFeatureSubsetStrategies: Array[String] =
+ Array("auto", "all", "sqrt", "log2", "onethird")
+
+ private[tree] class NodeIndexInfo(
+ val nodeIndexInGroup: Int,
+ val featureSubset: Option[Array[Int]]) extends Serializable
+
+ /**
+ * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
+ * This tracks the memory usage for aggregates and stops adding nodes when too much memory
+ * will be needed; this allows an adaptive number of nodes since different nodes may require
+ * different amounts of memory (if featureSubsetStrategy is not "all").
+ *
+ * @param nodeQueue Queue of nodes to split.
+ * @param maxMemoryUsage Bound on size of aggregate statistics, in bytes.
+ * @return (nodesForGroup, treeToNodeToIndexInfo).
+ * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
+ * treeToNodeToIndexInfo holds the index in the group and the selected features for each node:
+ * treeIndex --> (global) node index --> (node index in group, feature indices).
+ * The (global) node index is the index in the tree; the node index in group is the
+ * index in [0, numNodesInGroup) of the node in this group.
+ * The feature indices are None if not subsampling features.
+ */
+ private[tree] def selectNodesToSplit(
+ nodeQueue: mutable.Queue[(Int, Node)],
+ maxMemoryUsage: Long,
+ metadata: DecisionTreeMetadata,
+ rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
+ // Collect some nodes to split:
+ // nodesForGroup(treeIndex) = nodes to split
+ val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
+ val mutableTreeToNodeToIndexInfo =
+ new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
+ var memUsage: Long = 0L
+ var numNodesInGroup = 0
+ while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
+ val (treeIndex, node) = nodeQueue.head
+ // Choose subset of features for node (if subsampling).
+ val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
+ // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
+ Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
+ .take(metadata.numFeaturesPerNode).toArray)
+ } else {
+ None
+ }
+ // Check if enough memory remains to add this node to the group.
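+ // (aggregateSizeForNode returns a count of Double values, hence the factor of 8 bytes.)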
+ val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
+ if (memUsage + nodeMemUsage <= maxMemoryUsage) {
+ nodeQueue.dequeue()
+ mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
+ mutableTreeToNodeToIndexInfo
+ .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
+ = new NodeIndexInfo(numNodesInGroup, featureSubset)
+ }
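+ // If the node did not fit, it stays on the queue; memUsage still increases past the bound
+ // below, so the loop terminates and the node is retried in a later group.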
+ numNodesInGroup += 1
+ memUsage += nodeMemUsage
+ }
+ // Convert mutable maps to immutable ones.
+ val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
+ val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
+ (nodesForGroup, treeToNodeToIndexInfo)
+ }
+
+ /**
+ * Get the number of values to be stored for this node in the bin aggregates.
+ * @param featureSubset Indices of features which may be split at this node.
+ * If None, then use all features.
+ */
+ private[tree] def aggregateSizeForNode(
+ metadata: DecisionTreeMetadata,
+ featureSubset: Option[Array[Int]]): Long = {
+ val totalBins = if (featureSubset.nonEmpty) {
+ featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum
+ } else {
+ metadata.numBins.map(_.toLong).sum
+ }
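+ // Classification stores one value per class per bin; regression stores three values per bin
+ // (weighted count, weighted sum of labels, and weighted sum of squared labels).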
+ if (metadata.isClassification) {
+ metadata.numClasses * totalBins
+ } else {
+ 3 * totalBins
+ }
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
new file mode 100644
index 0000000000..937c8a2ac5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+/**
+ * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
+ * particularly for bagging (e.g., for random forests).
+ *
+ * This holds one instance, as well as an array of weights which represent the (weighted)
+ * number of times which this instance appears in each subsample.
+ * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
+ * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
+ *
+ * @param datum Data instance
+ * @param subsampleWeights Weight of this instance in each subsampled dataset.
+ *
+ * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
+ * dataset support, update. (We store subsampleWeights as Double for this future extension.)
+ */
+private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
+ extends Serializable
+
+private[tree] object BaggedPoint {
+
+ /**
+ * Convert an input dataset into its BaggedPoint representation,
+ * choosing subsample counts for each instance.
+ * Each subsample has the same number of instances as the original dataset,
+ * and is created by subsampling with replacement.
+ * @param input Input dataset.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation
+ */
+ def convertToBaggedRDD[Datum](
+ input: RDD[Datum],
+ numSubsamples: Int,
+ seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // TODO: Support different sampling rates, and sampling without replacement.
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
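+ // Poisson(1) counts approximate each instance's multiplicity under bootstrap sampling
+ // (sampling with replacement at the original dataset size).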
+ val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1))
+ instances.map { instance =>
+ val subsampleWeights = new Array[Double](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ subsampleWeights(subsampleIndex) = poisson.nextInt()
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleWeights)
+ }
+ }
+ }
+
+ def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+ input.map(datum => new BaggedPoint(datum, Array(1.0)))
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 61a9424671..d49df7a016 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -17,16 +17,17 @@
package org.apache.spark.mllib.tree.impl
+import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.impurity._
/**
* DecisionTree statistics aggregator.
* This holds a flat array of statistics for a set of (nodes, features, bins)
* and helps with indexing.
+ * This class is abstract to support learning with and without feature subsampling.
*/
-private[tree] class DTStatsAggregator(
- val metadata: DecisionTreeMetadata,
- val numNodes: Int) extends Serializable {
+private[tree] abstract class DTStatsAggregator(
+ val metadata: DecisionTreeMetadata) extends Serializable {
/**
* [[ImpurityAggregator]] instance specifying the impurity type.
@@ -43,18 +44,6 @@ private[tree] class DTStatsAggregator(
*/
val statsSize: Int = impurityAggregator.statsSize
- val numFeatures: Int = metadata.numFeatures
-
- /**
- * Number of bins for each feature. This is indexed by the feature index.
- */
- val numBins: Array[Int] = metadata.numBins
-
- /**
- * Number of splits for the given feature.
- */
- def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex)
-
/**
* Indicator for each feature of whether that feature is an unordered feature.
* TODO: Is Array[Boolean] any faster?
@@ -62,30 +51,14 @@ private[tree] class DTStatsAggregator(
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)
/**
- * Offset for each feature for calculating indices into the [[allStats]] array.
- */
- private val featureOffsets: Array[Int] = {
- numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
- }
-
- /**
- * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
- */
- private val nodeStride: Int = featureOffsets.last
-
- /**
* Total number of elements stored in this aggregator.
*/
- val allStatsSize: Int = numNodes * nodeStride
+ def allStatsSize: Int
/**
- * Flat array of elements.
- * Index for start of stats for a (node, feature, bin) is:
- * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
- * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex))
- * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex))
+ * Get flat array of elements stored in this aggregator.
*/
- val allStats: Array[Double] = new Array[Double](allStatsSize)
+ protected def allStats: Array[Double]
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
@@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator(
/**
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
*/
- def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
- val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label)
+ def update(
+ nodeIndex: Int,
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ instanceWeight: Double): Unit = {
+ val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label, instanceWeight)
}
/**
* Pre-compute node offset for use with [[nodeUpdate]].
*/
- def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+ def getNodeOffset(nodeIndex: Int): Int
/**
* Faster version of [[update]].
* Update the stats for a given (node, feature, bin) for ordered features, using the given label.
* @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
*/
- def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
- val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label)
- }
+ def nodeUpdate(
+ nodeOffset: Int,
+ nodeIndex: Int,
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ instanceWeight: Double): Unit
/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
* For ordered features only.
*/
- def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
- require(!isUnordered(featureIndex),
- s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" +
- s" for unordered feature $featureIndex.")
- nodeIndex * nodeStride + featureOffsets(featureIndex)
- }
+ def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int
/**
* Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
@@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator(
def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
- s" but was called for ordered feature $featureIndex.")
- val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
- (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
+ s" but was called for ordered feature $featureIndex.")
+ val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex)
+ (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize)
}
/**
@@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator(
* (node, feature, left/right child) offset from
* [[getLeftRightNodeFeatureOffsets]].
*/
- def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = {
- impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label)
+ def nodeFeatureUpdate(
+ nodeFeatureOffset: Int,
+ binIndex: Int,
+ label: Double,
+ instanceWeight: Double): Unit = {
+ impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label,
+ instanceWeight)
}
/**
@@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator(
}
this
}
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when not subsampling features.
+ *
+ * @param numNodes Number of nodes to collect statistics for.
+ */
+private[tree] class DTStatsAggregatorFixedFeatures(
+ metadata: DecisionTreeMetadata,
+ numNodes: Int) extends DTStatsAggregator(metadata) {
+
+ /**
+ * Offset for each feature for calculating indices into the [[allStats]] array.
+ * Mapping: featureIndex --> offset
+ */
+ private val featureOffsets: Array[Int] = {
+ metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
+ }
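+ // E.g., with statsSize = 2 and metadata.numBins = [4, 2, 3], featureOffsets = [0, 8, 12, 18]
+ // and nodeStride = featureOffsets.last = 18.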
+
+ /**
+ * Number of elements for each node, corresponding to stride between nodes in [[allStats]].
+ */
+ private val nodeStride: Int = featureOffsets.last
+ override val allStatsSize: Int = numNodes * nodeStride
+
+ /**
+ * Flat array of elements.
+ * Index for start of stats for a (node, feature, bin) is:
+ * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize
+ * Note: For unordered features, the left child stats occupy bins [0, numBins(featureIndex) / 2)
+ * and the right child stats occupy bins [numBins(featureIndex) / 2, numBins(featureIndex)).
+ */
+ override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+ override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride
+
+ override def nodeUpdate(
+ nodeOffset: Int,
+ nodeIndex: Int,
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ instanceWeight: Double): Unit = {
+ val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label, instanceWeight)
+ }
+
+ override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+ nodeIndex * nodeStride + featureOffsets(featureIndex)
+ }
+}
+
+/**
+ * DecisionTree statistics aggregator.
+ * This holds a flat array of statistics for a set of (nodes, features, bins)
+ * and helps with indexing.
+ *
+ * This instance of [[DTStatsAggregator]] is used when subsampling features.
+ *
+ * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo,
+ * where nodeIndexInfo stores the index in the group and the
+ * feature subsets (if using feature subsets).
+ */
+private[tree] class DTStatsAggregatorSubsampledFeatures(
+ metadata: DecisionTreeMetadata,
+ treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) {
+
+ /**
+ * For each node, offset for each feature for calculating indices into the [[allStats]] array.
+ * Mapping: (node index in group) --> (feature index within the node's subset) --> offset
+ */
+ private val featureOffsets: Array[Array[Int]] = {
+ val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum
+ val offsets = new Array[Array[Int]](numNodes)
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) =>
+ nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) =>
+ offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_))
+ .scanLeft(0)((total, nBins) => total + statsSize * nBins)
+ }
+ }
+ offsets
+ }
+
+ /**
+ * Offset of each node's stats within the [[allStats]] array. Mapping: (node index in group) --> offset
+ */
+ protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _)
+
+ override val allStatsSize: Int = nodeOffsets.last
+
+ /**
+ * Flat array of elements.
+ * Index for start of stats for a (node, feature, bin) is:
+ * index = nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
+ * Note: For unordered features, the left child stats occupy the first half of the feature's
+ * bins and the right child stats occupy the second half.
+ */
+ override protected val allStats: Array[Double] = new Array[Double](allStatsSize)
+
+ override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex)
+
+ /**
+ * Faster version of [[update]].
+ * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
+ * @param nodeOffset Pre-computed node offset from [[getNodeOffset]].
+ * @param featureIndex Index of the feature within this node's feature subset.
+ * Note: This is NOT the original (global) feature index.
+ */
+ override def nodeUpdate(
+ nodeOffset: Int,
+ nodeIndex: Int,
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ instanceWeight: Double): Unit = {
+ val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize
+ impurityAggregator.update(allStats, i, label, instanceWeight)
+ }
+
+ /**
+ * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]].
+ * For ordered features only.
+ * @param featureIndex Index of the feature within this node's feature subset.
+ * Note: This is NOT the original (global) feature index.
+ */
+ override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = {
+ nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex)
+ }
}
private[tree] object DTStatsAggregator extends Serializable {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index b6d49e5555..212dce2523 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata(
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
- val minInfoGain: Double) extends Serializable {
+ val minInfoGain: Double,
+ val numTrees: Int,
+ val numFeaturesPerNode: Int) extends Serializable {
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
@@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata(
numBins(featureIndex) - 1
}
+ /**
+ * Indicates if feature subsampling is being used.
+ */
+ def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode
+
}
private[tree] object DecisionTreeMetadata {
@@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata {
* This computes which categorical features will be ordered vs. unordered,
* as well as the number of splits and bins for each feature.
*/
- def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {
+ def buildMetadata(
+ input: RDD[LabeledPoint],
+ strategy: Strategy,
+ numTrees: Int,
+ featureSubsetStrategy: String): DecisionTreeMetadata = {
val numFeatures = input.take(1)(0).features.size
val numExamples = input.count()
@@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata {
}
}
+ // Set number of features to use per node (for random forests).
+ val _featureSubsetStrategy = featureSubsetStrategy match {
+ case "auto" =>
+ if (numTrees == 1) {
+ "all"
+ } else {
+ if (strategy.algo == Classification) {
+ "sqrt"
+ } else {
+ "onethird"
+ }
+ }
+ case _ => featureSubsetStrategy
+ }
+ val numFeaturesPerNode: Int = _featureSubsetStrategy match {
+ case "all" => numFeatures
+ case "sqrt" => math.sqrt(numFeatures).ceil.toInt
+ case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ case "onethird" => (numFeatures / 3.0).ceil.toInt
+ }
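+ // E.g., with numFeatures = 100: "sqrt" -> 10, "log2" -> 7, "onethird" -> 34 features per node.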
+
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain)
+ strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
}
/**
+ * Version of [[buildMetadata()]] for DecisionTree.
+ */
+ def buildMetadata(
+ input: RDD[LabeledPoint],
+ strategy: Strategy): DecisionTreeMetadata = {
+ buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
+ }
+
+ /**
* Given the arity of a categorical feature (arity = number of categories),
* return the number of bins for the feature if it is to be treated as an unordered feature.
* There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 1c8afc2d0f..0e02345aa3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int)
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
if (label >= statsSize) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
- allStats(offset + label.toInt) += 1
+ allStats(offset + label.toInt) += instanceWeight
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 5cfdf345d1..7c83cd48e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int)
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
+ def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
if (label >= statsSize) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
- allStats(offset + label.toInt) += 1
+ allStats(offset + label.toInt) += instanceWeight
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 5a047d6cb5..60e2ab2bb8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double): Unit
+ def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index e9ccecb1b8..df9eafa5da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -75,10 +75,10 @@ private[tree] class VarianceAggregator()
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double): Unit = {
- allStats(offset) += 1
- allStats(offset + 1) += label
- allStats(offset + 2) += label * label
+ def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+ allStats(offset) += instanceWeight
+ allStats(offset + 1) += instanceWeight * label
+ allStats(offset + 2) += instanceWeight * label * label
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5f0095d23c..56c3e25d92 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector
@DeveloperApi
class Node (
val id: Int,
- val predict: Double,
- val isLeaf: Boolean,
- val split: Option[Split],
+ var predict: Double,
+ var isLeaf: Boolean,
+ var split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
- val stats: Option[InformationGainStats]) extends Serializable with Logging {
+ var stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
"split = " + split + ", stats = " + stats
@@ -168,6 +168,11 @@ class Node (
private[tree] object Node {
/**
+ * Return a node with the given node id (but nothing else set).
+ */
+ def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+
+ /**
* Return the index of the left child of this node.
*/
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
new file mode 100644
index 0000000000..538c0e2332
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ * Random forest model for classification or regression.
+ * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
+ * aggregate predictions.
+ * @param trees Trees which make up this forest. This cannot be empty.
+ * @param algo algorithm type -- classification or regression
+ */
+@Experimental
+class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
+
+ require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+
+ /**
+ * Predict values for a single data point.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ def predict(features: Vector): Double = {
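+ // Classification: plurality vote over the trees' predicted classes.
+ // Regression: mean of the trees' predicted values.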
+ algo match {
+ case Classification =>
+ val predictionToCount = new mutable.HashMap[Int, Int]()
+ trees.foreach { tree =>
+ val prediction = tree.predict(features).toInt
+ predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+ }
+ predictionToCount.maxBy(_._2)._1
+ case Regression =>
+ trees.map(_.predict(features)).sum / trees.size
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = {
+ features.map(x => predict(x))
+ }
+
+ /**
+ * Get number of trees in forest.
+ */
+ def numTrees: Int = trees.size
+
+ /**
+ * Print full model.
+ */
+ override def toString: String = {
+ val header = algo match {
+ case Classification =>
+ s"RandomForestModel classifier with $numTrees trees\n"
+ case Regression =>
+ s"RandomForestModel regressor with $numTrees trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"RandomForestModel given unknown algo parameter: $algo.")
+ }
+ header + trees.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+}
+
+private[tree] object RandomForestModel {
+
+ def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
+ require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
+ val algo: Algo = trees(0).algo
+ require(trees.forall(_.algo == algo),
+ "RandomForestModel cannot combine trees which have different output types" +
+ " (classification/regression).")
+ new RandomForestModel(trees, algo)
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2b2e579b99..a48ed71a1c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.scalatest.FunSuite
@@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
- def validateClassifier(
- model: DecisionTreeModel,
- input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
- prediction != expected.label
- }
- val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy,
- s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
- }
-
- def validateRegressor(
- model: DecisionTreeModel,
- input: Seq[LabeledPoint],
- requiredMSE: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
- }
-
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
- // 2^10 - 1 > 100, so categorical features will be ordered
+ // 2^(10-1) - 1 > 100, so categorical features will be ordered
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
@@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 0)
assert(bins(0).length === 0)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode: Node, doneTraining: Boolean) =
- DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories === List(1.0))
@@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories.length === 1)
@@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
- validateRegressor(model, arr, 0.0)
+ DecisionTreeSuite.validateRegressor(model, arr, 0.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
}
@@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
numClassesForClassification = 2, maxBins = 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
- val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
- // Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
- rootNodeCopy1, splits, bins, 10)
- assert(rootNode.leftNode.nonEmpty)
- assert(rootNode.rightNode.nonEmpty)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
val children1 = new Array[Node](2)
- children1(0) = rootNode.leftNode.get
- children1(1) = rootNode.rightNode.get
-
- // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
- // level tree construction.
- val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
- rootNodeCopy2, splits, bins, 0)
- assert(rootNode2.leftNode.nonEmpty)
- assert(rootNode2.rightNode.nonEmpty)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
val children2 = new Array[Node](2)
children2(0) = rootNode2.leftNode.get
children2(1) = rootNode2.rightNode.get
@@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
}
@@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
assert(model.topNode.split.get.feature === 1)
@@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 1)
@@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 1)
@@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(strategy.isMulticlassClassification)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.6)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.6)
}
test("split must satisfy min instances per node requirements") {
@@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
- val predicts = input.map(p => model.predict(p.features)).collect()
+ val predicts = rdd.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}
- // test for findBestSplits when no valid split can be found
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ // test when no valid split can be found
+ val rootNode = model.topNode
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
numClassesForClassification = 2, minInstancesPerNode = 2)
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
val gain = rootNode.stats.get
@@ -757,12 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(predict == 0.0)
}
- // test for findBestSplits when no valid split can be found
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ // test when no valid split can be found
+ val rootNode = model.topNode
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
object DecisionTreeSuite {
+ def validateClassifier(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
new file mode 100644
index 0000000000..30669fcd1c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.util.StatCounter
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends FunSuite with LocalSparkContext {
+
+ test("BaggedPoint RDD: without subsampling") {
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ val rdd = sc.parallelize(arr)
+ val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
+ baggedRDD.collect().foreach { baggedPoint =>
+ assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 1.0)
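+ // Poisson(1) subsample counts have mean 1 and standard deviation 1.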
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("Binary classification with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val numTrees = 1
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
+
+ val dt = DecisionTree.train(rdd, strategy)
+
+ RandomForestSuite.validateClassifier(rf, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
+
+ // Make sure trees are the same.
+ assert(rfTree.toString == dt.toString)
+ }
+
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val numTrees = 1
+
+ val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
+
+ val dt = DecisionTree.train(rdd, strategy)
+
+ RandomForestSuite.validateRegressor(rf, arr, 0.01)
+ DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
+
+ // Make sure trees are the same.
+ assert(rfTree.toString == dt.toString)
+ }
+
+ test("Binary classification with continuous features: subsampling features") {
+ val numFeatures = 50
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ // Select feature subsets for the top node of each tree and check the subset sizes.
+ def checkFeatureSubsetStrategy(
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ numFeaturesPerNode: Int): Unit = {
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val maxMemoryUsage: Long = 128 * 1024L * 1024L
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+ seeds.foreach { seed =>
+ val failString = s"Failed on test with:" +
+ s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+ s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ val topNodes: Array[Node] = new Array[Node](numTrees)
+ Range(0, numTrees).foreach { treeIndex =>
+ topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
+ nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+ }
+ val rng = new scala.util.Random(seed = seed)
+ val (nodesForGroup: Map[Int, Array[Node]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+ assert(nodesForGroup.size === numTrees, failString)
+ assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+ if (numFeaturesPerNode == numFeatures) {
+ // featureSubset values should all be None
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+ failString)
+ } else {
+ // Check number of features.
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+ _.featureSubset.get.size === numFeaturesPerNode)), failString)
+ }
+ }
+ }
+
+ checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+ }
+
+}
+
+object RandomForestSuite {
+
+ /**
+ * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+ * epsilon of the expected values.
+ * @param data Every element of the data should be an i.i.d. sample from some distribution.
+ */
+ def testRandomArrays(
+ data: Array[Array[Double]],
+ numCols: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double) {
+ val values = new mutable.ArrayBuffer[Double]()
+ data.foreach { row =>
+ assert(row.size == numCols)
+ values ++= row
+ }
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ def validateClassifier(
+ model: RandomForestModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: RandomForestModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
+ def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
+ val numInstances = 1000
+ val arr = new Array[LabeledPoint](numInstances)
+ for (i <- 0 until numInstances) {
+ val label = if (i < numInstances / 10) {
+ 0.0
+ } else if (i < numInstances / 2) {
+ 1.0
+ } else if (i < numInstances * 0.9) {
+ 0.0
+ } else {
+ 1.0
+ }
+ val features = Array.fill[Double](numFeatures)(i.toDouble)
+ arr(i) = new LabeledPoint(label, Vectors.dense(features))
+ }
+ arr
+ }
+
+}