author     MechCoder <manojkumarsivaraj334@gmail.com>   2015-01-26 19:46:17 -0800
committer  Xiangrui Meng <meng@databricks.com>          2015-01-26 19:46:17 -0800
commit     d6894b1c5314c751cfdaf78005b99b2104e6e4d1 (patch)
tree       dc8c7c806097d81235c99de9a972eb356ab8eaf4
parent     f2ba5c6fc3dde81a4d234c75dae2d4e3b46512d1 (diff)
[SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests
I've added support for a sampling_rate not equal to 1.0. I have two major questions:

1. A Scala style check is failing, since the number of parameters now exceeds 10.
2. I would like suggestions on how to test this.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4073 from MechCoder/spark-3726 and squashes the following commits:

8012fb2 [MechCoder] Add test in Strategy
e0e0d9c [MechCoder] TST: Add better test
d1df1b2 [MechCoder] Add test to verify subsampling behavior
a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala            16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala   3
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala       16
3 files changed, 24 insertions(+), 11 deletions(-)
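For context before the diff, here is a minimal caller-side sketch of what this patch enables: training a random forest with a subsamplingRate below 1.0. The sketch is not part of the patch; `data` (an RDD[LabeledPoint]) is an assumed input, and the other names follow MLlib as of this commit.

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(algo = Algo.Classification, impurity = Gini,
      maxDepth = 4, numClasses = 2,
      categoricalFeaturesInfo = Map.empty[Int, Int])
    // Before this patch the rate was silently forced to 1.0 when numTrees > 1.
    strategy.subsamplingRate = 0.7
    val model = RandomForest.trainClassifier(data, strategy, numTrees = 10,
      featureSubsetStrategy = "auto", seed = 42)  // data: RDD[LabeledPoint], assumed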
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index e9304b5e5c..482dd4b272 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -140,6 +140,7 @@ private class RandomForest (
     logDebug("maxBins = " + metadata.maxBins)
     logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
     logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+    logDebug("subsamplingRate = " + strategy.subsamplingRate)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
@@ -155,19 +156,12 @@ private class RandomForest (
     // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
 
-    val (subsample, withReplacement) = {
-      // TODO: Have a stricter check for RF in the strategy
-      val isRandomForest = numTrees > 1
-      if (isRandomForest) {
-        (1.0, true)
-      } else {
-        (strategy.subsamplingRate, false)
-      }
-    }
-
+    val withReplacement = if (numTrees > 1) true else false
+
     val baggedInput
-      = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
-        .persist(StorageLevel.MEMORY_AND_DISK)
+      = BaggedPoint.convertToBaggedRDD(treeInput,
+          strategy.subsamplingRate, numTrees,
+          withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
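The effect of the two modes chosen above is easiest to see per row. BaggedPoint is private to MLlib, so the following is only a conceptual sketch, not the actual implementation, under the assumption that with replacement (forests) each row receives roughly a Poisson(subsamplingRate) count per tree, while without replacement (single trees) each row is kept with probability subsamplingRate:

    import scala.util.Random

    // Knuth's method for a Poisson(lambda) draw; adequate for small lambda.
    def poisson(lambda: Double, rng: Random): Int = {
      val l = math.exp(-lambda)
      var k = 0
      var p = 1.0
      do { k += 1; p *= rng.nextDouble() } while (p > l)
      k - 1
    }

    // One sampling weight per tree for a single input row.
    def rowWeights(numTrees: Int, subsamplingRate: Double,
        withReplacement: Boolean, rng: Random): Array[Double] =
      Array.fill(numTrees) {
        if (withReplacement) poisson(subsamplingRate, rng).toDouble
        else if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0
      }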
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 972959885f..3308adb675 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -156,6 +156,9 @@ class Strategy (
       s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
     require(maxMemoryInMB <= 10240,
       s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
+    require(subsamplingRate > 0 && subsamplingRate <= 1,
+      s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
+      s"$subsamplingRate")
   }
 
   /** Returns a shallow copy of this instance. */
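A sketch of what the new require buys, assuming (as the surrounding checks suggest) that this validation runs when training starts: an out-of-range rate now fails fast with an IllegalArgumentException instead of silently producing empty or degenerate samples. `rdd` is an assumed RDD[LabeledPoint].

    import scala.util.Try
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val bad = new Strategy(algo = Algo.Classification, impurity = Gini,
      maxDepth = 2, numClasses = 2, subsamplingRate = 0.0)
    // The require above rejects rates outside (0, 1] during validation,
    // so training is assumed to throw IllegalArgumentException here.
    val failed = Try(RandomForest.trainClassifier(rdd, bad, numTrees = 3,
      featureSubsetStrategy = "auto", seed = 1)).isFailure  // expected: true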
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index f7f0f20c6c..55e963977b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
       featureSubsetStrategy = "sqrt", seed = 12345)
     EnsembleTestHelper.validateClassifier(model, arr, 1.0)
   }
+
+  test("subsampling rate in RandomForest"){
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
+      useNodeIdCache = true)
+
+    val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+      featureSubsetStrategy = "auto", seed = 123)
+    strategy.subsamplingRate = 0.5
+    val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf1.toDebugString != rf2.toDebugString)
+  }
+
 }
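One note on the author's testing question: the inequality assertion in the new test is meaningful only if training is deterministic for a fixed seed. A complementary check, sketched here under that assumption with `rdd` and `strategy` as defined in the test, is that retraining at the same rate and seed reproduces the model exactly.

    strategy.subsamplingRate = 0.5
    val rfA = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
      featureSubsetStrategy = "auto", seed = 123)
    val rfB = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
      featureSubsetStrategy = "auto", seed = 123)
    // Same seed and same rate should give identical forests, so the earlier
    // inequality can only come from changing subsamplingRate.
    assert(rfA.toDebugString == rfB.toDebugString)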