 python/pyspark/context.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 2f741cb345..c2b49ff37a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -141,14 +141,21 @@ class SparkContext(object):
     def parallelize(self, c, numSlices=None):
         """
         Distribute a local Python collection to form an RDD.
+
+        >>> sc.parallelize(range(5), 5).glom().collect()
+        [[0], [1], [2], [3], [4]]
         """
         numSlices = numSlices or self.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        if self.batchSize != 1:
-            c = batched(c, self.batchSize)
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(c):
+            c = list(c)    # Make it a list so we can compute its length
+        batchSize = min(len(c) // numSlices, self.batchSize)
+        if batchSize > 1:
+            c = batched(c, batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
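
The effect of the new cap is easiest to see on the doctest above: under the old
code, any batchSize of 5 or more would pickle all of range(5) into a single
batch, so one partition received every element and the other four came back
empty from glom(). Below is a minimal, self-contained sketch of the new logic,
outside the patch itself; this batched helper is a simplified stand-in for
PySpark's internal one (the real helper wraps chunks for the serializer), and
choose_batch_size is a hypothetical name used only for this illustration.

    # Sketch of the even-distribution logic in the patch, runnable on its own.
    def batched(iterable, batch_size):
        # Chunk an iterable into lists of at most batch_size elements.
        batch = []
        for item in iterable:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def choose_batch_size(c, numSlices, batchSize):
        # Mirror of the patch: cap the batch size so a small collection
        # still spreads across all slices.  With 5 elements and 5 slices,
        # len(c) // numSlices == 1, so batching is skipped entirely.
        if "__len__" not in dir(c):
            c = list(c)  # materialize so len() works
        return c, min(len(c) // numSlices, batchSize)

    c, bs = choose_batch_size(range(5), numSlices=5, batchSize=10)
    records = list(batched(c, bs)) if bs > 1 else list(c)
    print(bs, records)  # -> 1 [0, 1, 2, 3, 4]  (one record per element)

Running the sketch prints "1 [0, 1, 2, 3, 4]": the computed batch size is
len(c) // numSlices == 1, so batching is skipped and each of the five slices
can receive exactly one record, which is the layout the new doctest asserts.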