author    Xiangrui Meng <meng@databricks.com>      2014-06-26 21:46:55 -0700
committer Patrick Wendell <pwendell@gmail.com>     2014-06-26 21:46:55 -0700
commit    c23f5db32b3bd4d965d56e5df684a3b814a91cd6
tree      b9eabeef9945a5b5938c523f66d01f647376e20b /core/src/test
parent    d1636dd72fc4966413baeb97ba55b313dc1da63d
[SPARK-2251] fix concurrency issues in random sampler
The following code is very likely to throw an exception:

~~~
val rdd = sc.parallelize(0 until 111, 10).sample(false, 0.1)
rdd.zip(rdd).count()
~~~

because the same random number generator is used while computing partitions.

Author: Xiangrui Meng <meng@databricks.com>

Closes #1229 from mengxr/fix-sample and squashes the following commits:

f1ee3d7 [Xiangrui Meng] fix concurrency issues in random sampler
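For context, a minimal standalone sketch of the failure mode (the `sharedRng` name and the plain `Seq` stand-in are illustrative, not from the patch): zipping an RDD with itself computes each partition twice inside the same task, so two sample passes draw from one mutable random number generator and produce different streams, which violates zip's assumption that both sides of each partition have the same length.

~~~
import scala.util.Random

// Illustrative only: one RNG shared by two passes over the "same" data,
// which is what happens when a sampled RDD is zipped with itself.
val sharedRng = new Random(42L)

def sample(items: Seq[Int]): Seq[Int] =
  items.filter(_ => sharedRng.nextDouble() < 0.1)

val data = 0 until 111
val left = sample(data)   // consumes some random draws
val right = sample(data)  // consumes the following draws, so it differs
println(s"left: ${left.size} elements, right: ${right.size} elements")
// The two "identical" samples typically end up with different sizes,
// so zipping them element by element fails at runtime.
~~~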
Diffstat (limited to 'core/src/test')
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala   | 18
-rw-r--r--  core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala     | 18
2 files changed, 26 insertions, 10 deletions
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 00c273df63..5dd8de319a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.rdd
import org.scalatest.FunSuite
import org.apache.spark.SharedSparkContext
-import org.apache.spark.util.random.RandomSampler
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, RandomSampler}
/** a sampler that outputs its seed */
class MockSampler extends RandomSampler[Long, Long] {
@@ -32,7 +32,7 @@ class MockSampler extends RandomSampler[Long, Long] {
}
override def sample(items: Iterator[Long]): Iterator[Long] = {
- return Iterator(s)
+ Iterator(s)
}
override def clone = new MockSampler
@@ -40,11 +40,21 @@ class MockSampler extends RandomSampler[Long, Long] {
class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
- test("seedDistribution") {
+ test("seed distribution") {
val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
val sampler = new MockSampler
val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
- assert(sample.distinct.count == 2, "Seeds must be different.")
+ assert(sample.distinct().count == 2, "Seeds must be different.")
+ }
+
+ test("concurrency") {
+ // SPARK-2251: zip with self computes each partition twice.
+ // We want to make sure there are no concurrency issues.
+ val rdd = sc.parallelize(0 until 111, 10)
+ for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
+ val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+ sampled.zip(sampled).count()
+ }
}
}
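The production change lives under core/src/main and is outside this diffstat, but the new "concurrency" test together with MockSampler's existing `clone` override points at the pattern: compute each partition against a freshly cloned, freshly seeded sampler instead of the shared instance. A self-contained sketch of that pattern, with simplified names (`Sampler`, `computePartition`) that are not from the patch:

~~~
import scala.util.Random

// Sketch of the clone-per-computation pattern the new test exercises; the
// class and method names here are simplified stand-ins, not Spark's.
class Sampler(fraction: Double) extends Cloneable {
  var rng: Random = new Random

  def setSeed(seed: Long): Unit = rng.setSeed(seed)

  def sample(items: Iterator[Int]): Iterator[Int] =
    items.filter(_ => rng.nextDouble() < fraction)

  // A fresh instance, and therefore a fresh RNG, per clone.
  override def clone: Sampler = new Sampler(fraction)
}

val shared = new Sampler(0.1)

// Each computation seeds and drives its own copy of the shared sampler, so
// concurrent or repeated passes over the same partition cannot interfere.
def computePartition(seed: Long, items: Seq[Int]): Seq[Int] = {
  val local = shared.clone
  local.setSeed(seed)
  local.sample(items.iterator).toList
}

// Computing the same partition twice now yields identical samples, which is
// exactly what rdd.zip(rdd) relies on.
val a = computePartition(0L, 0 until 111)
val b = computePartition(0L, 0 until 111)
assert(a == b)
~~~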
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index e166787f17..36877476e7 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -42,7 +42,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
}
}
@@ -54,7 +55,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
}
}
@@ -66,7 +68,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.35)(random)
+ val sampler = new BernoulliSampler[Int](0.35)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
}
}
@@ -78,7 +81,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)(random)
+ val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
+ sampler.rng = random
assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
}
}
@@ -88,7 +92,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
random.setSeed(10L)
}
whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.2)(random)
+ val sampler = new BernoulliSampler[Int](0.2)
+ sampler.rng = random
sampler.setSeed(10L)
}
}
@@ -100,7 +105,8 @@ class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
}
whenExecuting(poisson) {
- val sampler = new PoissonSampler[Int](0.2)(poisson)
+ val sampler = new PoissonSampler[Int](0.2)
+ sampler.rng = poisson
assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
}
}
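The RandomSamplerSuite changes above show the API shift on the sampler side: the RNG is no longer injected through a second constructor argument list but is a settable member (`sampler.rng = random`). The core/src sources are not part of this diffstat, so the following is only a reconstruction consistent with how the tests use `BernoulliSampler`; the real implementation differs in details (for example, the field is likely package-private and backed by a different RNG type).

~~~
import java.util.Random

// Reconstruction for illustration only, inferred from the test usage above;
// not the actual Spark source.
trait RandomSampler[T, U] extends Cloneable with Serializable {
  def setSeed(seed: Long): Unit
  def sample(items: Iterator[T]): Iterator[U]
  override def clone: RandomSampler[T, U] =
    throw new UnsupportedOperationException("clone() is not implemented")
}

class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
  extends RandomSampler[T, T] {

  // Single-argument form used in the tests: keep roughly `fraction` of the items.
  def this(fraction: Double) = this(0.0d, fraction)

  // A settable member instead of a constructor argument, which is what lets
  // the tests write `sampler.rng = random` with a mocked Random.
  var rng: Random = new Random

  override def setSeed(seed: Long): Unit = rng.setSeed(seed)

  // Keep an item when its draw falls in [lb, ub), or outside that range when
  // `complement` is set, matching the expectations in the tests above.
  override def sample(items: Iterator[T]): Iterator[T] = items.filter { _ =>
    val x = rng.nextDouble()
    (x >= lb && x < ub) ^ complement
  }

  // Cloning produces an independent sampler with its own RNG.
  override def clone: BernoulliSampler[T] =
    new BernoulliSampler[T](lb, ub, complement)
}
~~~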