diff options
author | Davies Liu <davies@databricks.com> | 2014-11-07 20:53:03 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2014-11-07 20:53:03 -0800 |
commit | 7779109796c90d789464ab0be35917f963bbe867 (patch) | |
tree | f3086d40f5f144c436b70ecc81919d83cf5a30a6 | |
parent | 5923dd986ba26d0fcc8707dd8d16863f1c1005cb (diff) | |
download | spark-7779109796c90d789464ab0be35917f963bbe867.tar.gz spark-7779109796c90d789464ab0be35917f963bbe867.tar.bz2 spark-7779109796c90d789464ab0be35917f963bbe867.zip |
[SPARK-4304] [PySpark] Fix sort on empty RDD
This PR fixes sortBy()/sortByKey() on an empty RDD.
This should be backported into 1.1/1.2.
Author: Davies Liu <davies@databricks.com>
Closes #3162 from davies/fix_sort and squashes the following commits:
84f64b7 [Davies Liu] add tests
52995b5 [Davies Liu] fix sortByKey() on empty RDD
-rw-r--r-- | python/pyspark/rdd.py | 2 | ||||
-rw-r--r-- | python/pyspark/tests.py | 3 |
2 files changed, 5 insertions, 0 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 879655dc53..08d0474026 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -521,6 +521,8 @@ class RDD(object): # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them rddSize = self.count() + if not rddSize: + return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 9f625c5c6c..491e445a21 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -649,6 +649,9 @@ class RDDTests(ReusedPySparkTestCase): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + def test_sample(self): rdd = self.sc.parallelize(range(0, 100), 4) wo = rdd.sample(False, 0.1, 2).collect() |