path: root/mllib/src/test
author     Joseph K. Bradley <joseph.kurata.bradley@gmail.com>  2014-09-28 21:44:50 -0700
committer  Xiangrui Meng <meng@databricks.com>  2014-09-28 21:44:50 -0700
commit     0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe (patch)
tree       f3d82bc455227282e96e471b45a87fb07923edce /mllib/src/test
parent     f350cd307045c2c02e713225d8f1247f18ba123e (diff)
[SPARK-1545] [mllib] Add Random Forests
This PR adds RandomForest to MLlib. The implementation is basic, and future performance optimizations will be important. (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level (a rough sketch of this loop appears after the diffstat below)

This implementation is based on an implementation by manishamde and the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations). Thank you for your inputs!

## Testing

Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets.

Performance: This has been performance tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs). Results below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from feature subsampling.

In the table below, each row is a single (random) dataset. The 2 different sets of result columns are for 2 different RF implementations:
* (numTrees): This is from an earlier commit, after implementing RandomForest to train multiple trees at once. It does not include any code for feature subsampling.
* (feature subsets): This is from this current PR's code, after implementing feature subsampling.

These tests were to identify regressions in DecisionTree, so they train 1 tree with all of the features (i.e., no feature subsampling). They were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels). Speedup values < 1 indicate slowdowns from the old DecisionTree implementation.

numInstances | numFeatures | runtime (sec) (numTrees) | speedup (numTrees) | runtime (sec) (feature subsets) | speedup (feature subsets)
---- | ---- | ---- | ---- | ---- | ----
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees=10 and with sqrt(numFeatures), and those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees.

# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed. (It split levels into groups, but we now use a queue of nodes.)
* Many small changes to support RFs. (Note: These methods should be moved to RandomForest.scala in a later PR, but are in DecisionTree.scala to make code comparison easier.)

## RandomForest
* Main train() method is from old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly to track memory usage.

## RandomForestModel
* Stores an Array[DecisionTreeModel]
* Prediction: For classification, most common label. For regression, mean. We could support other methods later.

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.

## DTStatsAggregator
* 2 types of functionality (with and without subsampling features): These require different indexing methods. (We could treat both as subsampling, but this is less efficient.) DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality.

## impurities
* These now take instance weights.

## Node
* Some vals changed to vars.
* This is unfortunately a public API change (DeveloperApi). It could be avoided by creating a LearningNode struct, but that would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on.
* RandomForest public API: What options should be easily supported by the train* methods? Should ALL options be in the Java-friendly constructors? Should there be a constructor taking Strategy? (A usage sketch of the current API is appended after this message.)
* Feature subsampling options: What options should be supported? scikit-learn supports the same options, except for "onethird." One option would be to allow users to specify fractions ("0.1"): the current options could be supported, and any unrecognized values would be parsed as Doubles in [0,1].
* Splits and bins are computed before bootstrapping, so all trees use the same discretization.
* One queue, instead of one queue per tree.

CC: mengxr manishamde codedeft chouqin

Please let me know if you have suggestions---thanks!

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>
Author: qiping.lqp <qiping.lqp@alibaba-inc.com>
Author: chouqin <liqiping1991@gmail.com>

Closes #2435 from jkbradley/rfs-new and squashes the following commits:

c694174 [Joseph K. Bradley] Fixed typo
cc59d78 [Joseph K. Bradley] fixed imports
e25909f [Joseph K. Bradley] Simplified node group maps. Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits
fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression. Updated docs with references.
ef7c293 [Joseph K. Bradley] Updates based on code review. Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite
593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features). Now it checks >= 1.
a1a08df [Joseph K. Bradley] Removed old comments
866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds.
ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option
bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues.
d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest). Fixed bugs so that RandomForest now runs.
746d43c [Joseph K. Bradley] Implemented feature subsampling. Tested DecisionTree but not RandomForest.
6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new. Added RandomForestModel.toString
b7ae594 [Joseph K. Bradley] Updated docs. Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes.
121c74e [Joseph K. Bradley] Basic random forests are implemented. Random features per node not yet implemented. Test suite not implemented.
325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new
4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy.
a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet
eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new
5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1
0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes
e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune
19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
f1d11d1 [chouqin] fix typo
c7ebaf1 [chouqin] fix typo
39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test
c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py
306120f [Joseph K. Bradley] Fixed typo in DecisionTreeModel.scala doc
eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix
d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated
d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160
1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training.
0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree
d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1
2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code:
efcc736 [qiping.lqp] fix bug
10b8012 [qiping.lqp] fix style
6728fad [qiping.lqp] minor fix: remove empty lines
bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune
cadd569 [qiping.lqp] add api docs
46b891f [qiping.lqp] fix bug
e72c7e4 [qiping.lqp] add comments
845c6fa [qiping.lqp] fix style
f195e83 [qiping.lqp] fix style
987cbf4 [qiping.lqp] fix bug
ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain
ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
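To make the public API concrete, here is a minimal usage sketch based on the calls exercised in RandomForestSuite below. It assumes a `trainingData: RDD[LabeledPoint]` supplied by the caller; `RandomForestUsageSketch` is a hypothetical wrapper for illustration, not part of this commit.

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.rdd.RDD

object RandomForestUsageSketch {
  def run(trainingData: RDD[LabeledPoint]): Unit = {
    // Same Strategy class DecisionTree uses; only tree-level settings live here.
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 2)
    // featureSubsetStrategy options tested in the suite: "auto", "all", "sqrt",
    // "log2", "onethird". "sqrt" samples sqrt(numFeatures) features per node.
    val model = RandomForest.trainClassifier(trainingData, strategy, numTrees = 10,
      featureSubsetStrategy = "sqrt", seed = 42)
    // Classification predicts the most common label across the 10 trees.
    val firstPoint = trainingData.first()
    println(model.predict(firstPoint.features))
  }
}
```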
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  210
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala  245
2 files changed, 332 insertions(+), 123 deletions(-)
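The commit message above describes the key structural change: instead of training level-by-level in fixed groups, nodes from all trees share one queue, and memory-bounded groups are drawn from it. The following is a rough sketch of that loop, assuming the package-private helper signatures exactly as they appear in the test diff below; it is not the committed RandomForest.run.

```scala
package org.apache.spark.mllib.tree

import scala.collection.mutable

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.rdd.RDD

object QueueLoopSketch {
  // Train numTrees trees by repeatedly selecting a memory-bounded group of
  // nodes from one shared queue and splitting them in a single pass over the
  // bagged data. Sketch only: assumes findBestSplits enqueues new children.
  def trainSketch(input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int,
      featureSubsetStrategy: String, seed: Int): Array[Node] = {
    val metadata =
      DecisionTreeMetadata.buildMetadata(input, strategy, numTrees, featureSubsetStrategy)
    // Splits/bins are computed once, before bootstrapping, so all trees share
    // the same discretization (a design decision noted above).
    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    val rng = new scala.util.Random(seed = seed)

    val topNodes = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
    val nodeQueue = new mutable.Queue[(Int, Node)]()  // (treeIndex, node)
    topNodes.indices.foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
    while (nodeQueue.nonEmpty) {
      // Nodes and feature subsets are chosen jointly so the group fits in memory.
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      // One pass over the data splits every node in the group.
      DecisionTree.findBestSplits(baggedInput, metadata, topNodes,
        nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
    }
    topNodes
  }
}
```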
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 2b2e579b99..a48ed71a1c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.scalatest.FunSuite
@@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
import org.apache.spark.mllib.util.LocalSparkContext
class DecisionTreeSuite extends FunSuite with LocalSparkContext {
- def validateClassifier(
- model: DecisionTreeModel,
- input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
- prediction != expected.label
- }
- val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy,
- s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
- }
-
- def validateRegressor(
- model: DecisionTreeModel,
- input: Seq[LabeledPoint],
- requiredMSE: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
- }
-
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 100,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1-> 10))
- // 2^10 - 1 > 100, so categorical features will be ordered
+ // 2^(10-1) - 1 > 100, so categorical features will be ordered
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
assert(!metadata.isUnordered(featureIndex = 0))
@@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 0)
assert(bins(0).length === 0)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode: Node, doneTraining: Boolean) =
- DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories === List(1.0))
@@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.categories.length === 1)
@@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
- validateRegressor(model, arr, 0.0)
+ DecisionTreeSuite.validateRegressor(model, arr, 0.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
}
@@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(bins.length === 2)
assert(bins(0).length === 100)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
-
- val split = rootNode.split.get
- assert(split.feature === 0)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val stats = rootNode.stats.get
assert(stats.gain === 0)
@@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
numClassesForClassification = 2, maxBins = 100)
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
- val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
- // Single group second level tree construction.
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
- rootNodeCopy1, splits, bins, 10)
- assert(rootNode.leftNode.nonEmpty)
- assert(rootNode.rightNode.nonEmpty)
+ val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
val children1 = new Array[Node](2)
- children1(0) = rootNode.leftNode.get
- children1(1) = rootNode.rightNode.get
-
- // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
- // level tree construction.
- val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
- rootNodeCopy2, splits, bins, 0)
- assert(rootNode2.leftNode.nonEmpty)
- assert(rootNode2.rightNode.nonEmpty)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
val children2 = new Array[Node](2)
children2(0) = rootNode2.leftNode.get
children2(1) = rootNode2.rightNode.get
@@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
}
@@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 2)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
assert(model.topNode.split.get.feature === 1)
@@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 1))
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 1.0)
+ DecisionTreeSuite.validateClassifier(model, arr, 1.0)
assert(model.numNodes === 3)
assert(model.depth === 1)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 1)
@@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.9)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = model.topNode
val split = rootNode.split.get
assert(split.feature === 1)
@@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
assert(split.feature === 0)
@@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(strategy.isMulticlassClassification)
val model = DecisionTree.train(rdd, strategy)
- validateClassifier(model, arr, 0.6)
+ DecisionTreeSuite.validateClassifier(model, arr, 0.6)
}
test("split must satisfy min instances per node requirements") {
@@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
- val model = DecisionTree.train(input, strategy)
+ val model = DecisionTree.train(rdd, strategy)
assert(model.topNode.isLeaf)
assert(model.topNode.predict == 0.0)
- val predicts = input.map(p => model.predict(p.features)).collect()
+ val predicts = rdd.map(p => model.predict(p.features)).collect()
predicts.foreach { predict =>
assert(predict == 0.0)
}
- // test for findBestSplits when no valid split can be found
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ // test when no valid split can be found
+ val rootNode = model.topNode
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- val input = sc.parallelize(arr)
+ val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
numClassesForClassification = 2, minInstancesPerNode = 2)
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+
+ val rootNode = DecisionTree.train(rdd, strategy).topNode
val split = rootNode.split.get
val gain = rootNode.stats.get
@@ -757,12 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(predict == 0.0)
}
- // test for findBestSplits when no valid split can be found
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
- null, splits, bins, 10)
+ // test when no valid split can be found
+ val rootNode = model.topNode
val gain = rootNode.stats.get
assert(gain == InformationGainStats.invalidInformationGainStats)
@@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
object DecisionTreeSuite {
+ def validateClassifier(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
new file mode 100644
index 0000000000..30669fcd1c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.util.StatCounter
+
+/**
+ * Test suite for [[RandomForest]].
+ */
+class RandomForestSuite extends FunSuite with LocalSparkContext {
+
+ test("BaggedPoint RDD: without subsampling") {
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ val rdd = sc.parallelize(arr)
+ val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
+ baggedRDD.collect().foreach { baggedPoint =>
+ assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("Binary classification with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val numTrees = 1
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
+
+ val dt = DecisionTree.train(rdd, strategy)
+
+ RandomForestSuite.validateClassifier(rf, arr, 0.9)
+ DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
+
+ // Make sure trees are the same.
+ assert(rfTree.toString == dt.toString)
+ }
+
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val numTrees = 1
+
+ val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf.trees.size === 1)
+ val rfTree = rf.trees(0)
+
+ val dt = DecisionTree.train(rdd, strategy)
+
+ RandomForestSuite.validateRegressor(rf, arr, 0.01)
+ DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
+
+ // Make sure trees are the same.
+ assert(rfTree.toString == dt.toString)
+ }
+
+ test("Binary classification with continuous features: subsampling features") {
+ val numFeatures = 50
+ val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+
+ // Select feature subsets for the top nodes and check the subset sizes via asserts.
+ def checkFeatureSubsetStrategy(
+ numTrees: Int,
+ featureSubsetStrategy: String,
+ numFeaturesPerNode: Int): Unit = {
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val maxMemoryUsage: Long = 128 * 1024L * 1024L
+ val metadata =
+ DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy)
+ seeds.foreach { seed =>
+ val failString = s"Failed on test with:" +
+ s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," +
+ s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed"
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ val topNodes: Array[Node] = new Array[Node](numTrees)
+ Range(0, numTrees).foreach { treeIndex =>
+ topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1)
+ nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))
+ }
+ val rng = new scala.util.Random(seed = seed)
+ val (nodesForGroup: Map[Int, Array[Node]],
+ treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) =
+ RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
+
+ assert(nodesForGroup.size === numTrees, failString)
+ assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree
+ if (numFeaturesPerNode == numFeatures) {
+ // featureSubset values should all be None
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)),
+ failString)
+ } else {
+ // Check number of features.
+ assert(treeToNodeToIndexInfo.values.forall(_.values.forall(
+ _.featureSubset.get.size === numFeaturesPerNode)), failString)
+ }
+ }
+ }
+
+ checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)
+
+ checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
+ checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "log2",
+ (math.log(numFeatures) / math.log(2)).ceil.toInt)
+ checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
+ }
+
+}
+
+object RandomForestSuite {
+
+ /**
+ * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+ * epsilon of the expected values.
+ * @param data Every element of the data should be an i.i.d. sample from some distribution.
+ */
+ def testRandomArrays(
+ data: Array[Array[Double]],
+ numCols: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double) {
+ val values = new mutable.ArrayBuffer[Double]()
+ data.foreach { row =>
+ assert(row.size == numCols)
+ values ++= row
+ }
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ def validateClassifier(
+ model: RandomForestModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: RandomForestModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
+ def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
+ val numInstances = 1000
+ val arr = new Array[LabeledPoint](numInstances)
+ for (i <- 0 until numInstances) {
+ val label = if (i < numInstances / 10) {
+ 0.0
+ } else if (i < numInstances / 2) {
+ 1.0
+ } else if (i < numInstances * 0.9) {
+ 0.0
+ } else {
+ 1.0
+ }
+ val features = Array.fill[Double](numFeatures)(i.toDouble)
+ arr(i) = new LabeledPoint(label, Vectors.dense(features))
+ }
+ arr
+ }
+
+}
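A note on the "BaggedPoint RDD: with subsampling" test above: the expected mean and stddev of 1.0 are exactly the mean and standard deviation of a Poisson(1) distribution, the usual way to approximate bootstrap counts in a distributed setting. The sketch below reproduces that distributional check standalone, assuming Poisson(1) resampling; that is an assumption for illustration, since the committed BaggedPoint internals are not shown in this diff.

```scala
import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.util.StatCounter

object PoissonBaggingSketch {
  // Assumed model of per-point subsample weights: one i.i.d. Poisson(1) draw
  // per subsample. Poisson(1) has mean 1.0 and stddev 1.0, matching the values
  // RandomForestSuite.testRandomArrays asserts (within epsilon).
  def simulateSubsampleWeights(numSubsamples: Int, seed: Long): Array[Double] = {
    val poisson = new PoissonDistribution(1.0)
    poisson.reseedRandomGenerator(seed)
    Array.fill(numSubsamples)(poisson.sample().toDouble)
  }

  def main(args: Array[String]): Unit = {
    // 1000 points x 100 subsamples, with a distinct seed per point.
    val weights = Array.tabulate(1000)(i => simulateSubsampleWeights(100, seed = i)).flatten
    val stats = new StatCounter(weights)
    // Both statistics should land close to 1.0, as the suite expects.
    println(s"mean=${stats.mean}, stdev=${stats.stdev}")
  }
}
```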