author     CodingCat <zhunansjtu@gmail.com>      2014-03-16 22:14:59 -0700
committer  Matei Zaharia <matei@databricks.com>  2014-03-16 22:14:59 -0700
commit     dc9654638f1d781ee1e54348fa41436b27793365 (patch)
tree       6b0278c8ca06dd78add6f035aad90d1608354549
parent     f5486e9f75d62919583da5ecf9a9ad00222b2227 (diff)
SPARK-1240: handle the case of an empty RDD in takeSample
https://spark-project.atlassian.net/browse/SPARK-1240

The current implementation does not handle the case of an empty RDD when running takeSample. In this patch, before calling sample() inside the takeSample API, I add a check for this case and return an empty Array when the RDD is empty; in sample(), I also add a check for invalid fraction values. I also add several lines to the test cases to cover this.

Author: CodingCat <zhunansjtu@gmail.com>

Closes #135 from CodingCat/SPARK-1240 and squashes the following commits:

fef57d4 [CodingCat] fix the same problem in PySpark
36db06b [CodingCat] create new test cases for takeSample from an empty rdd
810948d [CodingCat] further fix
a40e8fb [CodingCat] replace if with require
ad483fd [CodingCat] handle the case with empty RDD when take sample
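A minimal sketch of the behavior this patch guarantees (the RDD contents and seed below are illustrative; a live SparkContext `sc` is assumed):

// Hypothetical spark-shell session; assumes a live SparkContext `sc`.
val empty = sc.parallelize(Seq.empty[Int], 2)

// With this patch, takeSample on an empty RDD returns an empty array
// immediately; without the check, the resampling loop could never
// terminate, since sampling an empty RDD always yields fewer elements
// than requested.
empty.takeSample(false, 20, 1)  // => Array()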
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala       7
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala  7
-rw-r--r--  python/pyspark/rdd.py                                    4
3 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b50c9963b9..f8283fbbb9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -310,6 +310,7 @@ abstract class RDD[T: ClassTag](
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+ require(fraction >= 0.0, "Invalid fraction value: " + fraction)
if (withReplacement) {
new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
} else {
@@ -344,6 +345,10 @@ abstract class RDD[T: ClassTag](
throw new IllegalArgumentException("Negative number of elements requested")
}
+ if (initialCount == 0) {
+ return new Array[T](0)
+ }
+
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
@@ -362,7 +367,7 @@ abstract class RDD[T: ClassTag](
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
// If the first sample didn't turn out large enough, keep trying to take samples;
- // this shouldn't happen often because we use a big multiplier for thei initial size
+ // this shouldn't happen often because we use a big multiplier for the initial size
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
}
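One of the squashed commits ("replace if with require") swaps a hand-rolled check for Scala's built-in require, which throws IllegalArgumentException (with a "requirement failed:" prefix) when its condition is false. A hypothetical session showing the new guard, again assuming a live SparkContext `sc`:

// Hypothetical spark-shell session; the RDD and seed are illustrative.
val rdd = sc.parallelize(1 to 100, 2)

// A negative fraction now fails fast with a descriptive message:
rdd.sample(withReplacement = false, fraction = -0.1, seed = 42)
// java.lang.IllegalArgumentException: requirement failed: Invalid fraction value: -0.1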
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 60bcada552..9512e0e6ee 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -457,6 +457,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
test("takeSample") {
val data = sc.parallelize(1 to 100, 2)
+
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
@@ -488,6 +489,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("takeSample from an empty rdd") {
+ val emptySet = sc.parallelize(Seq.empty[Int], 2)
+ val sample = emptySet.takeSample(false, 20, 1)
+ assert(sample.length === 0)
+ }
+
test("randomSplit") {
val n = 600
val data = sc.parallelize(1 to n, 2)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6d549b40e5..f3b432ff24 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -268,6 +268,7 @@ class RDD(object):
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
"""
+ assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
# this is ported from scala/spark/RDD.scala
@@ -288,6 +289,9 @@ class RDD(object):
if (num < 0):
raise ValueError
+ if (initialCount == 0):
+ return list()
+
if initialCount > sys.maxint - 1:
maxSelected = sys.maxint - 1
else: