From 3014803ead0aac31f36f4387c919174877525ff4 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 14 Nov 2014 12:43:17 -0800 Subject: [SPARK-4398][PySpark] specialize sc.parallelize(xrange) `sc.parallelize(range(1 << 20), 1).count()` may take 15 seconds to finish and the rdd object stores the entire list, making task size very large. This PR adds a specialized version for xrange. JoshRosen davies Author: Xiangrui Meng Closes #3264 from mengxr/SPARK-4398 and squashes the following commits: 8953c41 [Xiangrui Meng] follow davies' suggestion cbd58e3 [Xiangrui Meng] specialize sc.parallelize(xrange) (cherry picked from commit abd581752f9314791a688690c07ad1bb68cc09fe) Signed-off-by: Xiangrui Meng --- python/pyspark/context.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) (limited to 'python/pyspark/context.py') diff --git a/python/pyspark/context.py b/python/pyspark/context.py index faa5952258..b6c991453d 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -289,12 +289,29 @@ class SparkContext(object): def parallelize(self, c, numSlices=None): """ - Distribute a local Python collection to form an RDD. + Distribute a local Python collection to form an RDD. If the input + represents a range, using xrange is recommended for performance. 
- >>> sc.parallelize(range(5), 5).glom().collect() - [[0], [1], [2], [3], [4]] + >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() + [[0], [2], [3], [4], [6]] + >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect() + [[], [0], [], [2], [4]] """ - numSlices = numSlices or self.defaultParallelism + numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism + if isinstance(c, xrange): + size = len(c) + if size == 0: + return self.parallelize([], numSlices) + step = c[1] - c[0] if size > 1 else 1 + start0 = c[0] + + def getStart(split): + return start0 + (split * size / numSlices) * step + + def f(split, iterator): + return xrange(getStart(split), getStart(split + 1), step) + + return self.parallelize([], numSlices).mapPartitionsWithIndex(f) # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). -- cgit v1.2.3