aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorAiHe <ai.he@ussuning.com>2015-05-15 20:42:35 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-15 20:42:59 -0700
commitf41be8fb38608c79ff69a85f0715de5ebd3ae2a5 (patch)
tree688ccc3132da4116137401be183a3be358e880da /mllib/src/main
parent8164fbc2557487b5b4a11dcf2d02c93f0141e1fc (diff)
downloadspark-f41be8fb38608c79ff69a85f0715de5ebd3ae2a5.tar.gz
spark-f41be8fb38608c79ff69a85f0715de5ebd3ae2a5.tar.bz2
spark-f41be8fb38608c79ff69a85f0715de5ebd3ae2a5.zip
[SPARK-7473] [MLLIB] Add reservoir sample in RandomForest
reservoir feature sample by using existing api Author: AiHe <ai.he@ussuning.com> Closes #5988 from AiHe/reservoir and squashes the following commits: e7a41ac [AiHe] remove non-robust testing case 28ffb9a [AiHe] set seed as rng.nextLong 37459e1 [AiHe] set fixed seed 1e98a4c [AiHe] [MLLIB][tree] Add reservoir sample in RandomForest (cherry picked from commit deb411335a09b91eb1f75421d77e1c3686719621) Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala6
1 files changed, 3 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 055e60c7d9..b347c450c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -36,6 +36,7 @@ import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.SamplingUtils
/**
* :: Experimental ::
@@ -473,9 +474,8 @@ object RandomForest extends Serializable with Logging {
val (treeIndex, node) = nodeQueue.head
// Choose subset of features for node (if subsampling).
val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
- // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir)
- Some(rng.shuffle(Range(0, metadata.numFeatures).toList)
- .take(metadata.numFeaturesPerNode).toArray)
+ Some(SamplingUtils.reservoirSampleAndCount(Range(0,
+ metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
} else {
None
}