diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-01-26 19:46:17 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-26 19:46:17 -0800 |
commit | d6894b1c5314c751cfdaf78005b99b2104e6e4d1 (patch) | |
tree | dc8c7c806097d81235c99de9a972eb356ab8eaf4 /mllib/src/test | |
parent | f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 (diff) | |
download | spark-d6894b1c5314c751cfdaf78005b99b2104e6e4d1.tar.gz spark-d6894b1c5314c751cfdaf78005b99b2104e6e4d1.tar.bz2 spark-d6894b1c5314c751cfdaf78005b99b2104e6e4d1.zip |
[SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests
I've added support for sampling_rate not equal to 1.0 . I have two major questions.
1. A Scala style test is failing, since the number of parameters now exceed 10.
2. I would like suggestions to understand how to test this.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #4073 from MechCoder/spark-3726 and squashes the following commits:
8012fb2 [MechCoder] Add test in Strategy
e0e0d9c [MechCoder] TST: Add better test
d1df1b2 [MechCoder] Add test to verify subsampling behavior
a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index f7f0f20c6c..55e963977b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { featureSubsetStrategy = "sqrt", seed = 12345) EnsembleTestHelper.validateClassifier(model, arr, 1.0) } + + test("subsampling rate in RandomForest"){ + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int], + useNodeIdCache = true) + + val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + strategy.subsamplingRate = 0.5 + val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3, + featureSubsetStrategy = "auto", seed = 123) + assert(rf1.toDebugString != rf2.toDebugString) + } + } |