author     Xiangrui Meng <meng@databricks.com>    2014-02-03 13:02:09 -0800
committer  Reynold Xin <rxin@apache.org>          2014-02-03 13:02:09 -0800
commit     23af00f9e0e5108f62cdb9629e3eb4e54bbaa321 (patch)
tree       34211d41568e5f412393bfa2787f133cf79d10cf
parent     1625d8c44693420de026138f3abecce2d12f895c (diff)
Merge pull request #528 from mengxr/sample. Closes #528.
Refactor RDD sampling and add randomSplit to RDD (update)

Replace SampledRDD with PartitionwiseSampledRDD, which accepts a RandomSampler instance as input. The current sampling with and without replacement can be integrated easily via BernoulliSampler and PoissonSampler. The benefits are:

1) RDD.randomSplit is implemented in the same way; related to https://github.com/apache/incubator-spark/pull/513
2) Stratified sampling and importance sampling can be implemented in the same manner as well.

Unit tests are included for the samplers and RDD.randomSplit. This should perform better than my previous request, where the BernoulliSampler created many Iterator instances: https://github.com/apache/incubator-spark/pull/513

Author: Xiangrui Meng <meng@databricks.com>

== Merge branch commits ==

commit e8ce957e5f0a600f2dec057924f4a2ca6adba373
Author: Xiangrui Meng <meng@databricks.com>
Date:   Mon Feb 3 12:21:08 2014 -0800

    more docs to PartitionwiseSampledRDD

commit fbb4586d0478ff638b24bce95f75ff06f713d43b
Author: Xiangrui Meng <meng@databricks.com>
Date:   Mon Feb 3 00:44:23 2014 -0800

    move XORShiftRandom to util.random and use it in BernoulliSampler

commit 987456b0ee8612fd4f73cb8c40967112dc3c4c2d
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sat Feb 1 11:06:59 2014 -0800

    relax assertions in SortingSuite because the RangePartitioner has large variance in this case

commit 3690aae416b2dc9b2f9ba32efa465ba7948477f4
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sat Feb 1 09:56:28 2014 -0800

    test split ratio of RDD.randomSplit

commit 8a410bc933a60c4d63852606f8bbc812e416d6ae
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sat Feb 1 09:25:22 2014 -0800

    add a test to ensure seed distribution and minor style update

commit ce7e866f674c30ab48a9ceb09da846d5362ab4b6
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 31 18:06:22 2014 -0800

    minor style change

commit 750912b4d77596ed807d361347bd2b7e3b9b7a74
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 31 18:04:54 2014 -0800

    fix some long lines

commit c446a25c38d81db02821f7f194b0ce5ab4ed7ff5
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 31 17:59:59 2014 -0800

    add complement to BernoulliSampler and minor style changes

commit dbe2bc2bd888a7bdccb127ee6595840274499403
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 31 17:45:08 2014 -0800

    switch to partition-wise sampling for better performance

commit a1fca5232308feb369339eac67864c787455bb23
Merge: ac712e4 cf6128f
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 31 16:33:09 2014 -0800

    Merge branch 'sample' of github.com:mengxr/incubator-spark into sample

commit cf6128fb672e8c589615adbd3eaa3cbdb72bd461
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 14:40:07 2014 -0800

    set SampledRDD deprecated in 1.0

commit f430f847c3df91a3894687c513f23f823f77c255
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 14:38:59 2014 -0800

    update code style

commit a8b5e2021a9204e318c80a44d00c5c495f1befb6
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 12:56:27 2014 -0800

    move package random to util.random

commit ab0fa2c4965033737a9e3a9bf0a59cbb0df6a6f5
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 12:50:35 2014 -0800

    add Apache headers and update code style

commit 985609fe1a55655ad11966e05a93c18c138a403d
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 11:49:25 2014 -0800

    add new lines

commit b21bddf29850a2c006a868869b8f91960a029322
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sun Jan 26 11:46:35 2014 -0800

    move samplers to random.IndependentRandomSampler and add tests

commit c02dacb4a941618e434cefc129c002915db08be6
Author: Xiangrui Meng <meng@databricks.com>
Date:   Sat Jan 25 15:20:24 2014 -0800

    add RandomSampler

commit 8ff7ba3c5cf1fc338c29ae8b5fa06c222640e89c
Author: Xiangrui Meng <meng@databricks.com>
Date:   Fri Jan 24 13:23:22 2014 -0800

    init impl of IndependentlySampledRDD
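For orientation, here is a minimal driver-program sketch of the reworked API. The data set, fractions, weights, and seed below are illustrative, not part of the patch:

    import org.apache.spark.SparkContext

    object SamplingSketch {
      def main(args: Array[String]) {
        // Illustrative local context and data set, not from the patch.
        val sc = new SparkContext("local", "sampling-sketch")
        val data = sc.parallelize(1 to 1000, 4)

        // Without replacement: routed to BernoulliSampler by this patch.
        val subset = data.sample(withReplacement = false, fraction = 0.1, seed = 42)

        // With replacement: routed to PoissonSampler, so items may repeat.
        val multiset = data.sample(withReplacement = true, fraction = 0.1, seed = 42)

        // New in this patch: weighted disjoint splits of one RDD.
        val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 42L)

        println(s"subset=${subset.count()} multiset=${multiset.count()} " +
          s"train=${train.count()} test=${test.count()}")
        sc.stop()
      }
    }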
core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala | 65
core/src/main/scala/org/apache/spark/rdd/RDD.scala | 26
core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala | 2
core/src/main/scala/org/apache/spark/util/Vector.scala | 1
core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala | 26
core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala | 94
core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala (renamed from core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala) | 6
core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala | 49
core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 15
core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala | 16
core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala | 99
core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala) | 4
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 3
13 files changed, 390 insertions(+), 16 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
new file mode 100644
index 0000000000..629f7074c1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.Random
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{TaskContext, Partition}
+import org.apache.spark.util.random.RandomSampler
+
+private[spark]
+class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
+ extends Partition with Serializable {
+ override val index: Int = prev.index
+}
+
+/**
+ * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD,
+ * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain
+ * a random sample of the records in the partition. The random seeds assigned to the samplers
+ * are guaranteed to have different values.
+ *
+ * @param prev RDD to be sampled
+ * @param sampler a random sampler
+ * @param seed random seed, defaults to System.nanoTime
+ * @tparam T input RDD item type
+ * @tparam U sampled RDD item type
+ */
+class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
+ prev: RDD[T],
+ sampler: RandomSampler[T, U],
+ seed: Long = System.nanoTime)
+ extends RDD[U](prev) {
+
+ override def getPartitions: Array[Partition] = {
+ val random = new Random(seed)
+ firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[U] = {
+ val split = splitIn.asInstanceOf[PartitionwiseSampledRDDPartition]
+ val thisSampler = sampler.clone
+ thisSampler.setSeed(split.seed)
+ thisSampler.sample(firstParent[T].iterator(split.prev, context))
+ }
+}
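The interesting part of compute above is the seeding scheme: one driver-side java.util.Random, seeded once in getPartitions, hands a distinct seed to each partition, and each task reseeds its own clone of the sampler. A standalone sketch of that scheme, with illustrative numbers:

    import java.util.Random

    object SeedingSketch extends App {
      val masterSeed = 42L
      val numPartitions = 4
      val random = new Random(masterSeed)
      // One seed per partition, drawn deterministically from the master seed,
      // so recomputing a lost partition reproduces exactly the same sample.
      val partitionSeeds = Array.fill(numPartitions)(random.nextLong())
      // Seeds are distinct with overwhelming probability (nextLong collisions
      // are possible in principle), which is what the "seedDistribution" test
      // in PartitionwiseSampledRDDSuite below checks.
      println(partitionSeeds.mkString(", "))
    }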
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1472c92b60..033d334079 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -45,6 +45,7 @@ import org.apache.spark.util.{Utils, BoundedPriorityQueue, SerializableHyperLogL
import org.apache.spark.SparkContext._
import org.apache.spark._
+import org.apache.spark.util.random.{PoissonSampler, BernoulliSampler}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -319,8 +320,29 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
- new SampledRDD(this, withReplacement, fraction, seed)
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+ if (withReplacement) {
+ new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+ } else {
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+ }
+ }
+
+ /**
+ * Randomly splits this RDD with the provided weights.
+ *
+ * @param weights weights for splits, will be normalized if they don't sum to 1
+ * @param seed random seed, defaults to System.nanoTime
+ *
+ * @return split RDDs in an array
+ */
+ def randomSplit(weights: Array[Double], seed: Long = System.nanoTime): Array[RDD[T]] = {
+ val sum = weights.sum
+ val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
+ normalizedCumWeights.sliding(2).map { x =>
+ new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+ }.toArray
+ }
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
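The cumulative-weight trick in randomSplit above is worth spelling out: the normalized weights become adjacent sub-ranges of [0, 1), one BernoulliSampler per range, and all splits share the same seed, so every element draws the same uniform value in each split and therefore lands in exactly one of them. A standalone sketch using the weights from the unit test:

    object SplitRangesSketch extends App {
      val weights = Array(1.0, 2.0, 3.0)
      val sum = weights.sum
      // Normalize, then accumulate: 0.0, 1/6, 3/6, 6/6.
      val cumWeights = weights.map(_ / sum).scanLeft(0.0)(_ + _)
      // Each adjacent pair becomes one acceptance range [lb, ub).
      val ranges = cumWeights.sliding(2).map(x => (x(0), x(1))).toList
      ranges.foreach { case (lb, ub) => println(f"[$lb%.4f, $ub%.4f)") }
      // Prints [0.0000, 0.1667), [0.1667, 0.5000), [0.5000, 1.0000)
    }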
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index d433670cc2..08534b6f1d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -25,11 +25,13 @@ import cern.jet.random.engine.DRand
import org.apache.spark.{Partition, TaskContext}
+@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0")
private[spark]
class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}
+@deprecated("Replaced by PartitionwiseSampledRDD", "1.0")
class SampledRDD[T: ClassTag](
prev: RDD[T],
withReplacement: Boolean,
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fcdf848637..83fa0bf1e5 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import scala.util.Random
+import org.apache.spark.util.random.XORShiftRandom
class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
new file mode 100644
index 0000000000..98569143ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+/**
+ * A class with pseudorandom behavior.
+ */
+trait Pseudorandom {
+ /** Set random seed. */
+ def setSeed(seed: Long)
+}
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
new file mode 100644
index 0000000000..6b66d54751
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import java.util.Random
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+/**
+ * A pseudorandom sampler. It is possible to change the sampled item type. For example, we might
+ * want to add weights for stratified sampling or importance sampling. Should only use
+ * transformations that are tied to the sampler and cannot be applied after sampling.
+ *
+ * @tparam T item type
+ * @tparam U sampled item type
+ */
+trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable {
+
+ /** take a random sample */
+ def sample(items: Iterator[T]): Iterator[U]
+
+ override def clone: RandomSampler[T, U] =
+ throw new NotImplementedError("clone() is not implemented.")
+}
+
+/**
+ * A sampler based on Bernoulli trials.
+ *
+ * @param lb lower bound of the acceptance range
+ * @param ub upper bound of the acceptance range
+ * @param complement whether to use the complement of the range specified, defaults to false
+ * @tparam T item type
+ */
+class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
+ (implicit random: Random = new XORShiftRandom)
+ extends RandomSampler[T, T] {
+
+ def this(ratio: Double)(implicit random: Random = new XORShiftRandom)
+ = this(0.0d, ratio)(random)
+
+ override def setSeed(seed: Long) = random.setSeed(seed)
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ items.filter { item =>
+ val x = random.nextDouble()
+ (x >= lb && x < ub) ^ complement
+ }
+ }
+
+ override def clone = new BernoulliSampler[T](lb, ub)
+}
+
+/**
+ * A sampler based on values drawn from a Poisson distribution.
+ *
+ * @param poisson a Poisson random number generator
+ * @tparam T item type
+ */
+class PoissonSampler[T](mean: Double)
+ (implicit var poisson: Poisson = new Poisson(mean, new DRand))
+ extends RandomSampler[T, T] {
+
+ override def setSeed(seed: Long) {
+ poisson = new Poisson(mean, new DRand(seed.toInt))
+ }
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ items.flatMap { item =>
+ val count = poisson.nextInt()
+ if (count == 0) {
+ Iterator.empty
+ } else {
+ Iterator.fill(count)(item)
+ }
+ }
+ }
+
+ override def clone = new PoissonSampler[T](mean)
+}
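To see the two semantics side by side without the Spark or colt dependencies, here is a plain-Scala sketch; the Poisson draw uses Knuth's inversion method purely for illustration, whereas the real class delegates to cern.jet.random.Poisson:

    import java.util.Random

    object SamplerSketch extends App {
      val rng = new Random(7L)

      // Bernoulli (without replacement): keep an item iff its uniform draw
      // falls in [lb, ub); the `complement` flag would negate this predicate.
      val lb = 0.0; val ub = 0.3
      val kept = (1 to 20).filter { _ =>
        val x = rng.nextDouble()
        x >= lb && x < ub
      }
      println(s"bernoulli: $kept")

      // Poisson (with replacement): emit each item 0..n times, n ~ Poisson(mean).
      def poissonDraw(mean: Double, r: Random): Int = {
        val limit = math.exp(-mean)
        var k = 0
        var p = 1.0
        do { k += 1; p *= r.nextDouble() } while (p > limit)
        k - 1
      }
      val repeated = (1 to 20).flatMap(item => Seq.fill(poissonDraw(0.3, rng))(item))
      println(s"poisson: $repeated")
    }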
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 08b31ac64f..20d32d01b5 100644
--- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.random
import java.util.{Random => JavaRandom}
import org.apache.spark.util.Utils.timeIt
@@ -46,6 +46,10 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
}
+
+ override def setSeed(s: Long) {
+ seed = s
+ }
}
/** Contains benchmark method and main method to run benchmark of the RNG */
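The hunk above shows only the package move and the new setSeed override. For readers without the full file, the generator's core is a Marsaglia xorshift step; the 21/35/4 shift triple below is the classic choice, but the exact constants are not visible in this hunk, so treat them as illustrative:

    object XorShiftSketch extends App {
      // Marsaglia xorshift on 64 bits; a zero seed is a fixed point, so real
      // implementations must avoid it. Shift constants are assumed, not taken
      // from this hunk.
      final class TinyXorShift(private var seed: Long) {
        def nextLong(): Long = {
          seed ^= (seed << 21)
          seed ^= (seed >>> 35)
          seed ^= (seed << 4)
          seed
        }
      }
      val rng = new TinyXorShift(42L)
      println(Seq.fill(5)(rng.nextLong()).mkString(", "))
    }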
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
new file mode 100644
index 0000000000..cfe96fb3f7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.SharedSparkContext
+import org.apache.spark.util.random.RandomSampler
+
+/** a sampler that outputs its seed */
+class MockSampler extends RandomSampler[Long, Long] {
+
+ private var s: Long = _
+
+ override def setSeed(seed: Long) {
+ s = seed
+ }
+
+ override def sample(items: Iterator[Long]): Iterator[Long] = {
+ return Iterator(s)
+ }
+
+ override def clone = new MockSampler
+}
+
+class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
+
+ test("seedDistribution") {
+ val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
+ val sampler = new MockSampler
+ val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+ assert(sample.distinct.count == 2, "Seeds must be different.")
+ }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 559ea051d3..cd01303bad 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -486,6 +486,21 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("randomSplit") {
+ val n = 600
+ val data = sc.parallelize(1 to n, 2)
+ for(seed <- 1 to 5) {
+ val splits = data.randomSplit(Array(1.0, 2.0, 3.0), seed)
+ assert(splits.size == 3, "wrong number of splits")
+ assert(splits.flatMap(_.collect).sorted.toList == data.collect.toList,
+ "incomplete or wrong split")
+ val s = splits.map(_.count)
+ assert(math.abs(s(0) - 100) < 50) // std = 9.13
+ assert(math.abs(s(1) - 200) < 50) // std = 11.55
+ assert(math.abs(s(2) - 300) < 50) // std = 12.25
+ }
+ }
+
test("runJob on an invalid partition") {
intercept[IllegalArgumentException] {
sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false)
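The tolerances in the randomSplit test come straight from the binomial model: with Bernoulli sampling each split size is Binomial(n, p), so its standard deviation is sqrt(n * p * (1 - p)), and the 50-element bound sits at roughly four standard deviations. A quick check of the inline std comments:

    object SplitStdCheck extends App {
      val n = 600
      for (w <- Seq(1.0, 2.0, 3.0)) {
        val p = w / 6.0 // weights 1:2:3 normalize to 1/6, 2/6, 3/6
        val std = math.sqrt(n * p * (1 - p))
        println(f"p=$p%.4f expected=${n * p}%.0f std=$std%.2f")
      }
      // Prints std = 9.13, 11.55, 12.25, matching the comments in the test.
    }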
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index 2f7bd370fc..e836119942 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -98,10 +98,10 @@ class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers
assert(sorted.collect() === pairArr.sortBy(_._1))
val partitions = sorted.collectPartitions()
logInfo("Partition lengths: " + partitions.map(_.length).mkString(", "))
- partitions(0).length should be > 180
- partitions(1).length should be > 180
- partitions(2).length should be > 180
- partitions(3).length should be > 180
+ val lengthArr = partitions.map(_.length)
+ lengthArr.foreach { len =>
+ assert(len > 100 && len < 400)
+ }
partitions(0).last should be < partitions(1).head
partitions(1).last should be < partitions(2).head
partitions(2).last should be < partitions(3).head
@@ -113,10 +113,10 @@ class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
val partitions = sorted.collectPartitions()
logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
- partitions(0).length should be > 180
- partitions(1).length should be > 180
- partitions(2).length should be > 180
- partitions(3).length should be > 180
+ val lengthArr = partitions.map(_.length)
+ lengthArr.foreach { len =>
+ assert(len > 100 && len < 400)
+ }
partitions(0).last should be > partitions(1).head
partitions(1).last should be > partitions(2).head
partitions(2).last should be > partitions(3).head
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
new file mode 100644
index 0000000000..0f4792cd3b
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.random
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.mock.EasyMockSugar
+
+import java.util.Random
+import cern.jet.random.Poisson
+
+class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
+
+ val a = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
+
+ var random: Random = _
+ var poisson: Poisson = _
+
+ before {
+ random = mock[Random]
+ poisson = mock[Poisson]
+ }
+
+ test("BernoulliSamplerWithRange") {
+ expecting {
+ for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+ random.nextDouble().andReturn(x)
+ }
+ }
+ whenExecuting(random)
+ {
+ val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+ assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
+ }
+ }
+
+ test("BernoulliSamplerWithRatio") {
+ expecting {
+ for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+ random.nextDouble().andReturn(x)
+ }
+ }
+ whenExecuting(random)
+ {
+ val sampler = new BernoulliSampler[Int](0.35)(random)
+ assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
+ }
+ }
+
+ test("BernoulliSamplerWithComplement") {
+ expecting {
+ for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
+ random.nextDouble().andReturn(x)
+ }
+ }
+ whenExecuting(random)
+ {
+ val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+ assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
+ }
+ }
+
+ test("BernoulliSamplerSetSeed") {
+ expecting {
+ random.setSeed(10L)
+ }
+ whenExecuting(random)
+ {
+ val sampler = new BernoulliSampler[Int](0.2)(random)
+ sampler.setSeed(10L)
+ }
+ }
+
+ test("PoissonSampler") {
+ expecting {
+ for(x <- Seq(0, 1, 2, 0, 1, 1, 0, 0, 0)) {
+ poisson.nextInt().andReturn(x)
+ }
+ }
+ whenExecuting(poisson) {
+ val sampler = new PoissonSampler[Int](0.2)(poisson)
+ assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index f1d7b61b31..352aa94219 100644
--- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -15,10 +15,8 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.random
-import java.util.Random
-import org.scalatest.FlatSpec
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.Utils.times
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 0dee9399a8..e508b76c3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -26,8 +26,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.util.XORShiftRandom
-
+import org.apache.spark.util.random.XORShiftRandom
/**