diff options
author | CodingCat <zhunansjtu@gmail.com> | 2014-03-16 22:14:59 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-03-16 22:14:59 -0700 |
commit | dc9654638f1d781ee1e54348fa41436b27793365 (patch) | |
tree | 6b0278c8ca06dd78add6f035aad90d1608354549 /core | |
parent | f5486e9f75d62919583da5ecf9a9ad00222b2227 (diff) | |
download | spark-dc9654638f1d781ee1e54348fa41436b27793365.tar.gz spark-dc9654638f1d781ee1e54348fa41436b27793365.tar.bz2 spark-dc9654638f1d781ee1e54348fa41436b27793365.zip |
SPARK-1240: handle the case of empty RDD when takeSample
https://spark-project.atlassian.net/browse/SPARK-1240
It seems that the current implementation does not handle the empty RDD case when run takeSample
In this patch, before calling sample() inside takeSample API, I add a checker for this case and returns an empty Array when it's a empty RDD; also in sample(), I add a checker for the invalid fraction value
In the test case, I also add several lines for this case
Author: CodingCat <zhunansjtu@gmail.com>
Closes #135 from CodingCat/SPARK-1240 and squashes the following commits:
fef57d4 [CodingCat] fix the same problem in PySpark
36db06b [CodingCat] create new test cases for takeSample from an empty red
810948d [CodingCat] further fix
a40e8fb [CodingCat] replace if with require
ad483fd [CodingCat] handle the case with empty RDD when take sample
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/RDD.scala | 7 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 7 |
2 files changed, 13 insertions, 1 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b50c9963b9..f8283fbbb9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -310,6 +310,7 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { + require(fraction >= 0.0, "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { @@ -344,6 +345,10 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (initialCount > Integer.MAX_VALUE - 1) { maxSelected = Integer.MAX_VALUE - 1 } else { @@ -362,7 +367,7 @@ abstract class RDD[T: ClassTag]( var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size + // this shouldn't happen often because we use a big multiplier for the initial size while (samples.length < total) { samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 60bcada552..9512e0e6ee 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -457,6 +457,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements @@ -488,6 +489,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("takeSample from an empty rdd") { + val emptySet = sc.parallelize(Seq.empty[Int], 2) + val sample = emptySet.takeSample(false, 20, 1) + assert(sample.length === 0) + } + test("randomSplit") { val n = 600 val data = sc.parallelize(1 to n, 2) |