about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 6
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 1
2 files changed, 3 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 055e60c7d9..b347c450c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.SamplingUtils
/**
* :: Experimental ::
@@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
- Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
- .take(metadata.numFeaturesPerNode).toArray)
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index ee3bc98486..4ed66953cb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,7 +196,6 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
numClasses = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
- EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
test("subsampling rate in RandomForest"){