aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-07-29 00:09:11 -0400
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-07-29 02:51:43 -0400
commitfeba7ee540fca28872957120e5e39b9e36466953 (patch)
treec4349aa082e6727f638bc360ba6d9352a88959bc /python/pyspark
parentd75c3086951f603ec30b2527c24559e053ed7f25 (diff)
downloadspark-feba7ee540fca28872957120e5e39b9e36466953.tar.gz
spark-feba7ee540fca28872957120e5e39b9e36466953.tar.bz2
spark-feba7ee540fca28872957120e5e39b9e36466953.zip
SPARK-815. Python parallelize() should split lists before batching
One unfortunate consequence of this fix is that we materialize any collections that are given to us as generators, but this seems necessary to get reasonable behavior on small collections. We could add a batchSize parameter later to bypass auto-computation of batch size if this becomes a problem (e.g. if users really want to parallelize big generators nicely)
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/context.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 2f741cb345..c2b49ff37a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -141,14 +141,21 @@ class SparkContext(object):
def parallelize(self, c, numSlices=None):
"""
Distribute a local Python collection to form an RDD.
+
+ >>> sc.parallelize(range(5), 5).glom().collect()
+ [[0], [1], [2], [3], [4]]
"""
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- if self.batchSize != 1:
- c = batched(c, self.batchSize)
+ # Make sure we distribute data evenly if it's smaller than self.batchSize
+ if "__len__" not in dir(c):
+ c = list(c) # Make it a list so we can compute its length
+ batchSize = min(len(c) // numSlices, self.batchSize)
+ if batchSize > 1:
+ c = batched(c, batchSize)
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()