aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorAiHe <ai.he@ussuning.com>2015-05-15 20:42:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-15 20:42:35 -0700
commitdeb411335a09b91eb1f75421d77e1c3686719621 (patch)
treebfb7e26708ebb3b8e33dda95e756c9a4141fb7f4 /mllib
parentd7b69946cb21cd2781c9ad3e691e54b28efbbf3d (diff)
downloadspark-deb411335a09b91eb1f75421d77e1c3686719621.tar.gz
spark-deb411335a09b91eb1f75421d77e1c3686719621.tar.bz2
spark-deb411335a09b91eb1f75421d77e1c3686719621.zip
[SPARK-7473] [MLLIB] Add reservoir sample in RandomForest
reservoir feature sample by using existing api Author: AiHe <ai.he@ussuning.com> Closes #5988 from AiHe/reservoir and squashes the following commits: e7a41ac [AiHe] remove non-robust testing case 28ffb9a [AiHe] set seed as rng.nextLong 37459e1 [AiHe] set fixed seed 1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala6
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala1
2 files changed, 3 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 055e60c7d9..b347c450c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.SamplingUtils
/**
* :: Experimental ::
@@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
- Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
- .take(metadata.numFeaturesPerNode).toArray)
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index ee3bc98486..4ed66953cb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
- EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
test("subsampling rate in RandomForest"){