aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorErik Erlandson <eerlands@redhat.com>2014-10-30 22:30:52 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-30 22:30:52 -0700
commitad3bd0dff8997861c5a04438145ba6f91c57a849 (patch)
tree8d99ab3fbef08b915b281f15e62ea9f6d4b9e32e /mllib
parent872fc669b497fb255db3212568f2a14c2ba0d5db (diff)
downloadspark-ad3bd0dff8997861c5a04438145ba6f91c57a849.tar.gz
spark-ad3bd0dff8997861c5a04438145ba6f91c57a849.tar.bz2
spark-ad3bd0dff8997861c5a04438145ba6f91c57a849.zip
[SPARK-3250] Implement Gap Sampling optimization for random sampling
More efficient sampling, based on Gap Sampling optimization: http://erikerlandson.github.io/blog/2014/09/11/faster-random-samples-with-gap-sampling/ Author: Erik Erlandson <eerlands@redhat.com> Closes #2455 from erikerlandson/spark-3250-pr and squashes the following commits: 72496bc [Erik Erlandson] [SPARK-3250] Implement Gap Sampling optimization for random sampling
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala4
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index b88e08bf14..9353351af7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliSampler
+import org.apache.spark.util.random.BernoulliCellSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.storage.StorageLevel
@@ -244,7 +244,7 @@ object MLUtils {
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
- val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+ val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)