author    Xiangrui Meng <meng@databricks.com>    2014-11-14 12:43:17 -0800
committer Xiangrui Meng <meng@databricks.com>    2014-11-14 12:43:25 -0800
commit    3014803ead0aac31f36f4387c919174877525ff4 (patch)
tree      0c6fd3005a0cc7922da595ff28eb545688a6b17c /python
parent    3219271f403091d4d3af4cddd08121ba538a459b (diff)
[SPARK-4398][PySpark] specialize sc.parallelize(xrange)
`sc.parallelize(range(1 << 20), 1).count()` may take 15 seconds to finish, and the RDD object stores the entire list, making the task size very large. This PR adds a specialized version for xrange.

JoshRosen davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #3264 from mengxr/SPARK-4398 and squashes the following commits:

8953c41 [Xiangrui Meng] follow davies' suggestion
cbd58e3 [Xiangrui Meng] specialize sc.parallelize(xrange)

(cherry picked from commit abd581752f9314791a688690c07ad1bb68cc09fe)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
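As a hedged usage sketch of the problem being fixed (assumes a local Python 2 PySpark setup of this era; the app name is illustrative, and the timings are the ones claimed above, not new measurements):

```python
from pyspark import SparkContext

sc = SparkContext("local", "spark-4398-demo")
# Slow path: range materializes 2**20 ints in the driver, and the whole
# list travels inside the task, making the task size very large.
sc.parallelize(range(1 << 20), 1).count()
# Fast path after this patch: xrange is recognized, and each partition
# rebuilds its slice lazily, so only the range description is shipped.
sc.parallelize(xrange(1 << 20), 1).count()
sc.stop()
```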
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/context.py  25
1 file changed, 21 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index faa5952258..b6c991453d 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -289,12 +289,29 @@ class SparkContext(object):
 
     def parallelize(self, c, numSlices=None):
         """
-        Distribute a local Python collection to form an RDD.
+        Distribute a local Python collection to form an RDD. Using xrange
+        is recommended if the input represents a range for performance.
 
-        >>> sc.parallelize(range(5), 5).glom().collect()
-        [[0], [1], [2], [3], [4]]
+        >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
+        [[0], [2], [3], [4], [6]]
+        >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
+        [[], [0], [], [2], [4]]
         """
-        numSlices = numSlices or self.defaultParallelism
+        numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
+        if isinstance(c, xrange):
+            size = len(c)
+            if size == 0:
+                return self.parallelize([], numSlices)
+            step = c[1] - c[0] if size > 1 else 1
+            start0 = c[0]
+
+            def getStart(split):
+                return start0 + (split * size / numSlices) * step
+
+            def f(split, iterator):
+                return xrange(getStart(split), getStart(split + 1), step)
+
+            return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
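The new branch derives the doctest output from pure arithmetic: `getStart(split)` maps a partition index to an offset into the range, and each partition is the xrange between consecutive boundaries. A standalone sketch of that slicing logic, runnable without Spark; it is written in Python 3 syntax (`range` for `xrange`, `//` for Python 2's integer `/`), and the helper name `slices` is illustrative, not part of PySpark:

```python
def slices(c, numSlices):
    # Mirror of the patch's xrange specialization.
    size = len(c)
    if size == 0:
        return [[] for _ in range(numSlices)]
    step = c[1] - c[0] if size > 1 else 1  # a range has a constant stride
    start0 = c[0]

    def getStart(split):
        # Partition boundary: floor(split * size / numSlices) elements in,
        # scaled by the stride and shifted by the range's start.
        return start0 + (split * size // numSlices) * step

    return [list(range(getStart(i), getStart(i + 1), step))
            for i in range(numSlices)]

print(slices(range(0, 6, 2), 5))  # [[], [0], [], [2], [4]] -- matches the new doctest
print(slices(range(5), 5))        # [[0], [1], [2], [3], [4]] -- matches the old one
```

Note the design choice in the patch: `self.parallelize([], numSlices)` only pins the number of partitions, and `mapPartitionsWithIndex(f)` then rebuilds each partition's slice of the range on the executor, so the driver never serializes the elements themselves.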