author     Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-12-29 15:34:57 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>    2012-12-29 15:34:57 -0800
commit     26186e2d259f3aa2db9c8594097fd342107ce147 (patch)
tree       ad5ddb982a9989c9f2ced19022dc6e34c3666d9f /pyspark
parent     6ee1ff2663cf1f776dd33e448548a8ddcf974dc6 (diff)
Use batching in pyspark parallelize(); fix cartesian()
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py      |  4
-rw-r--r--  pyspark/pyspark/rdd.py          | 31
-rw-r--r--  pyspark/pyspark/serializers.py  | 23
3 files changed, 31 insertions(+), 27 deletions(-)
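
The changes below all revolve around the Batch container defined in pyspark/pyspark/serializers.py: instead of pickling every element individually, parallelize() and the worker pipeline group elements into Batch objects so that fewer, larger pickles cross the Python/Java boundary. As a rough, illustrative stand-in for the container this diff relies on (the real class lives in serializers.py and may carry a docstring and other details), it is just a thin wrapper around a list:

# Illustrative stand-in for pyspark.serializers.Batch as used throughout this
# diff: a thin wrapper holding a list of items so that many elements can be
# pickled (and shipped to the JVM) as a single object.
class Batch(object):
    def __init__(self, items):
        self.items = items

b = Batch([1, 2, 3])
print(b.items)  # [1, 2, 3]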
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b90596ecc2..6172d69dcf 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length
+from pyspark.serializers import dump_pickle, write_with_length, batched
from pyspark.rdd import RDD
from py4j.java_collections import ListConverter
@@ -91,6 +91,8 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         atexit.register(lambda: os.unlink(tempFile.name))
+        if self.batchSize != 1:
+            c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
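
To make the new parallelize() path concrete, here is a small, self-contained sketch of what the hunk above does: when batchSize is not 1, the input collection is regrouped into Batch objects before each one is pickled and written, length-prefixed, to a temporary file that the JVM side later reads back. The batched() and write_with_length() helpers below are local stand-ins that only mirror the functions imported from pyspark.serializers; they are not the library's own implementations.

import pickle
import struct
import tempfile

class Batch(object):
    """Stand-in for pyspark.serializers.Batch: wraps a list of items."""
    def __init__(self, items):
        self.items = items

def batched(iterator, batch_size):
    # Group the iterator into Batch objects of at most batch_size items.
    items = []
    for item in iterator:
        items.append(item)
        if len(items) == batch_size:
            yield Batch(items)
            items = []
    if items:
        yield Batch(items)

def write_with_length(data, stream):
    # Length-prefixed framing: a 4-byte big-endian length, then the payload.
    stream.write(struct.pack("!i", len(data)))
    stream.write(data)

c = list(range(10))
batch_size = 4
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
    # With batch_size = 4, three pickles are written instead of ten.
    for batch in batched(iter(c), batch_size):
        write_with_length(pickle.dumps(batch), temp_file)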
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 20f84b2dd0..203f7377d2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,7 +2,7 @@ import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap
+from itertools import chain, ifilter, imap, product
import operator
import os
import shlex
@@ -123,12 +123,6 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
-
-        Union of batched and unbatched RDDs (internal test):
-
-        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
-        >>> rdd.union(batchedRDD).collect()
-        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
@@ -168,7 +162,18 @@ class RDD(object):
         >>> sorted(rdd.cartesian(rdd).collect())
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
-        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        # Due to batching, we can't use the Java cartesian method.
+        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        def unpack_batches(pair):
+            (x, y) = pair
+            if type(x) == Batch or type(y) == Batch:
+                xs = x.items if type(x) == Batch else [x]
+                ys = y.items if type(y) == Batch else [y]
+                for pair in product(xs, ys):
+                    yield pair
+            else:
+                yield pair
+        return java_cartesian.flatMap(unpack_batches)

     def groupBy(self, f, numSplits=None):
         """
@@ -293,8 +298,6 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4]).count()
         3
-        >>> sc.parallelize([Batch([2, 3, 4])]).count()
-        3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
@@ -667,12 +670,8 @@ class PipelinedRDD(RDD):
         if not self._bypass_serializer and self.ctx.batchSize != 1:
             oldfunc = self.func
             batchSize = self.ctx.batchSize
-            if batchSize == -1: # unlimited batch size
-                def batched_func(iterator):
-                    yield Batch(list(oldfunc(iterator)))
-            else:
-                def batched_func(iterator):
-                    return batched(oldfunc(iterator), batchSize)
+            def batched_func(iterator):
+                return batched(oldfunc(iterator), batchSize)
             func = batched_func
         cmds = [func, self._bypass_serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
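
The cartesian() fix is easiest to see outside of Spark. The sketch below re-implements the unpack_batches() generator from the hunk above on plain Python lists (no RDDs or JVM involved, and isinstance() is used in place of the type() comparison): when the Java-side cartesian product pairs up whole Batch objects, each (Batch, Batch) pair has to be expanded into the product of the underlying items, which is what flatMap(unpack_batches) achieves.

from itertools import product

class Batch(object):
    def __init__(self, items):
        self.items = items

def unpack_batches(pair):
    # Expand a pair that may contain Batch objects into element-level pairs.
    x, y = pair
    if isinstance(x, Batch) or isinstance(y, Batch):
        xs = x.items if isinstance(x, Batch) else [x]
        ys = y.items if isinstance(y, Batch) else [y]
        for unpacked in product(xs, ys):
            yield unpacked
    else:
        yield pair

# What the Java cartesian sees: pairs of serialized batches, not of elements.
left = [Batch([1, 2])]
right = [Batch([3, 4])]
java_style_pairs = product(left, right)
flattened = [p for pair in java_style_pairs for p in unpack_batches(pair)]
print(flattened)  # [(1, 3), (1, 4), (2, 3), (2, 4)]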
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 8b08f7ef0f..9a5151ea00 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -16,17 +16,20 @@ class Batch(object):
 def batched(iterator, batchSize):
-    items = []
-    count = 0
-    for item in iterator:
-        items.append(item)
-        count += 1
-        if count == batchSize:
+    if batchSize == -1: # unlimited batch size
+        yield Batch(list(iterator))
+    else:
+        items = []
+        count = 0
+        for item in iterator:
+            items.append(item)
+            count += 1
+            if count == batchSize:
+                yield Batch(items)
+                items = []
+                count = 0
+        if items:
             yield Batch(items)
-            items = []
-            count = 0
-    if items:
-        yield Batch(items)

 def dump_pickle(obj):
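
Finally, a quick illustration of the rewritten batched() generator (reproduced locally from the serializers.py hunk above): a positive batchSize yields fixed-size Batch objects with a smaller final batch, while batchSize == -1 collects everything into a single Batch, the case that used to be handled separately in PipelinedRDD.

class Batch(object):
    def __init__(self, items):
        self.items = items

def batched(iterator, batchSize):
    # Reproduced from the serializers.py hunk above.
    if batchSize == -1: # unlimited batch size
        yield Batch(list(iterator))
    else:
        items = []
        count = 0
        for item in iterator:
            items.append(item)
            count += 1
            if count == batchSize:
                yield Batch(items)
                items = []
                count = 0
        if items:
            yield Batch(items)

print([b.items for b in batched(iter(range(7)), 3)])   # [[0, 1, 2], [3, 4, 5], [6]]
print([b.items for b in batched(iter(range(7)), -1)])  # [[0, 1, 2, 3, 4, 5, 6]]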