From 26186e2d259f3aa2db9c8594097fd342107ce147 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 29 Dec 2012 15:34:57 -0800
Subject: Use batching in pyspark parallelize(); fix cartesian()

---
 pyspark/pyspark/context.py     |  4 +++-
 pyspark/pyspark/rdd.py         | 31 +++++++++++++++----------------
 pyspark/pyspark/serializers.py | 23 +++++++++++++----------
 3 files changed, 31 insertions(+), 27 deletions(-)

(limited to 'pyspark')

diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b90596ecc2..6172d69dcf 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
 
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length
+from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -91,6 +91,8 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         atexit.register(lambda: os.unlink(tempFile.name))
+        if self.batchSize != 1:
+            c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 20f84b2dd0..203f7377d2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,7 +2,7 @@ import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap
+from itertools import chain, ifilter, imap, product
 import operator
 import os
 import shlex
@@ -123,12 +123,6 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
-
-        Union of batched and unbatched RDDs (internal test):
-
-        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
-        >>> rdd.union(batchedRDD).collect()
-        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
@@ -168,7 +162,18 @@ class RDD(object):
         >>> sorted(rdd.cartesian(rdd).collect())
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
-        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        # Due to batching, we can't use the Java cartesian method.
+        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        def unpack_batches(pair):
+            (x, y) = pair
+            if type(x) == Batch or type(y) == Batch:
+                xs = x.items if type(x) == Batch else [x]
+                ys = y.items if type(y) == Batch else [y]
+                for pair in product(xs, ys):
+                    yield pair
+            else:
+                yield pair
+        return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numSplits=None):
         """
@@ -293,8 +298,6 @@ class RDD(object):
 
         >>> sc.parallelize([2, 3, 4]).count()
         3
-        >>> sc.parallelize([Batch([2, 3, 4])]).count()
-        3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
@@ -667,12 +670,8 @@ class PipelinedRDD(RDD):
         if not self._bypass_serializer and self.ctx.batchSize != 1:
             oldfunc = self.func
             batchSize = self.ctx.batchSize
-            if batchSize == -1: # unlimited batch size
-                def batched_func(iterator):
-                    yield Batch(list(oldfunc(iterator)))
-            else:
-                def batched_func(iterator):
-                    return batched(oldfunc(iterator), batchSize)
+            def batched_func(iterator):
+                return batched(oldfunc(iterator), batchSize)
             func = batched_func
         cmds = [func, self._bypass_serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 8b08f7ef0f..9a5151ea00 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -16,17 +16,20 @@ class Batch(object):
 
 
 def batched(iterator, batchSize):
-    items = []
-    count = 0
-    for item in iterator:
-        items.append(item)
-        count += 1
-        if count == batchSize:
+    if batchSize == -1: # unlimited batch size
+        yield Batch(list(iterator))
+    else:
+        items = []
+        count = 0
+        for item in iterator:
+            items.append(item)
+            count += 1
+            if count == batchSize:
+                yield Batch(items)
+                items = []
+                count = 0
+        if items:
             yield Batch(items)
-            items = []
-            count = 0
-    if items:
-        yield Batch(items)
 
 
 def dump_pickle(obj):
-- 
cgit v1.2.3
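
The serializers.py change above means parallelize() now pickles groups of elements rather than one element per record. A minimal standalone sketch of that behavior, reusing the Batch and batched() definitions as added by this patch (the demo input values are made up for illustration and are not part of the commit):

class Batch(object):
    """Stores multiple RDD entries as a single serialized record."""
    def __init__(self, items):
        self.items = items


def batched(iterator, batchSize):
    # Same logic as the new serializers.batched(): group the iterator's
    # elements into Batch objects of at most batchSize items each.
    if batchSize == -1: # unlimited batch size
        yield Batch(list(iterator))
    else:
        items = []
        count = 0
        for item in iterator:
            items.append(item)
            count += 1
            if count == batchSize:
                yield Batch(items)
                items = []
                count = 0
        if items:
            yield Batch(items)


print([b.items for b in batched(iter(range(5)), 2)])   # [[0, 1], [2, 3], [4]]
print([b.items for b in batched(iter(range(5)), -1)])  # [[0, 1, 2, 3, 4]]

With batchSize=2 the five demo elements become three Batch objects, and with batchSize == -1 everything lands in a single Batch, which is why the PipelinedRDD special case for -1 could be folded into batched() itself.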
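
The cartesian() fix follows from the same batching: pairs coming back from the Java cartesian method may now carry a Batch on either side. A sketch of how the unpack_batches helper added to rdd.py expands such pairs; the Batch stub and the sample pairs below are illustrative only, not part of the patch:

from itertools import product


class Batch(object):
    def __init__(self, items):
        self.items = items


def unpack_batches(pair):
    # Same logic as the helper added inside RDD.cartesian(): if either side
    # of a pair is a Batch, expand it into one pair per underlying element.
    (x, y) = pair
    if type(x) == Batch or type(y) == Batch:
        xs = x.items if type(x) == Batch else [x]
        ys = y.items if type(y) == Batch else [y]
        for pair in product(xs, ys):
            yield pair
    else:
        yield pair


pairs = [(Batch([1, 2]), 'a'), (3, 'b')]
print([p for pair in pairs for p in unpack_batches(pair)])
# -> [(1, 'a'), (2, 'a'), (3, 'b')]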