diff options
author | AiHe <ai.he@ussuning.com> | 2015-05-15 20:42:35 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-05-15 20:42:35 -0700 |
commit | deb411335a09b91eb1f75421d77e1c3686719621 (patch) | |
tree | bfb7e26708ebb3b8e33dda95e756c9a4141fb7f4 /mllib/src | |
parent | d7b69946cb21cd2781c9ad3e691e54b28efbbf3d (diff) | |
download | spark-deb411335a09b91eb1f75421d77e1c3686719621.tar.gz spark-deb411335a09b91eb1f75421d77e1c3686719621.tar.bz2 spark-deb411335a09b91eb1f75421d77e1c3686719621.zip |
[SPARK-7473] [MLLIB] Add reservoir sample in RandomForest
reservoir feature sample by using existing api
Author: AiHe <ai.he@ussuning.com>
Closes #5988 from AiHe/reservoir and squashes the following commits:
e7a41ac [AiHe] remove non-robust testing case
28ffb9a [AiHe] set seed as rng.nextLong
37459e1 [AiHe] set fixed seed
1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 6 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 1 |
2 files changed, 3 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 055e60c7d9..b347c450c1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import org.apache.spark.util.random.SamplingUtils /** * :: Experimental :: @@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging { val (treeIndex, node) = nodeQueue.head // Choose subset of features for node (if subsampling). val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { - // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) - Some(rng.shuffle(Range(0, metadata.numFeatures).toList) - .take(metadata.numFeaturesPerNode).toArray) + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1) } else { None } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index ee3bc98486..4ed66953cb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext { numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, featureSubsetStrategy = "sqrt", seed = 12345) - EnsembleTestHelper.validateClassifier(model, arr, 1.0) } test("subsampling rate in RandomForest"){ |