diff options
author:    Erik Erlandson <eerlands@redhat.com>  2014-10-30 22:30:52 -0700
committer: Xiangrui Meng <meng@databricks.com>   2014-10-30 22:30:52 -0700
commit:    ad3bd0dff8997861c5a04438145ba6f91c57a849 (patch)
tree:      8d99ab3fbef08b915b281f15e62ea9f6d4b9e32e /mllib/src
parent:    872fc669b497fb255db3212568f2a14c2ba0d5db (diff)
download:  spark-ad3bd0dff8997861c5a04438145ba6f91c57a849.tar.gz
           spark-ad3bd0dff8997861c5a04438145ba6f91c57a849.tar.bz2
           spark-ad3bd0dff8997861c5a04438145ba6f91c57a849.zip
[SPARK-3250] Implement Gap Sampling optimization for random sampling
More efficient sampling, based on Gap Sampling optimization:
http://erikerlandson.github.io/blog/2014/09/11/faster-random-samples-with-gap-sampling/
Author: Erik Erlandson <eerlands@redhat.com>
Closes #2455 from erikerlandson/spark-3250-pr and squashes the following commits:
72496bc [Erik Erlandson] [SPARK-3250] Implement Gap Sampling optimization for random sampling
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 4 |
1 file changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index b88e08bf14..9353351af7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliSampler
+import org.apache.spark.util.random.BernoulliCellSampler
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.storage.StorageLevel
@@ -244,7 +244,7 @@ object MLUtils {
   def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
     val numFoldsF = numFolds.toFloat
     (1 to numFolds).map { fold =>
-      val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+      val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
         complement = false)
       val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
       val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)