From e2dad15621f5dc15275b300df05483afde5025a0 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Wed, 26 Dec 2012 17:34:24 -0800
Subject: Add support for batched serialization of Python objects in PySpark.

---
 pyspark/pyspark/context.py     |  3 ++-
 pyspark/pyspark/rdd.py         | 57 +++++++++++++++++++++++++++++-------------
 pyspark/pyspark/serializers.py | 34 ++++++++++++++++++++++++-
 3 files changed, 74 insertions(+), 20 deletions(-)

(limited to 'pyspark')

diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 19f9f9e133..032619693a 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -17,13 +17,14 @@ class SparkContext(object):
     readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
-    def __init__(self, master, name, defaultParallelism=None):
+    def __init__(self, master, name, defaultParallelism=None, batchSize=-1):
         self.master = master
         self.name = name
         self._jsc = self.jvm.JavaSparkContext(master, name)
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+        self.batchSize = batchSize  # -1 represents an unlimited batch size
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 01908cff96..d7081dffd2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,6 +2,7 @@ import atexit
 from base64 import standard_b64encode as b64enc
 from collections import defaultdict
 from itertools import chain, ifilter, imap
+import operator
 import os
 import shlex
 from subprocess import Popen, PIPE
@@ -9,7 +10,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+    read_from_pickle_file
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
@@ -83,6 +85,11 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
+
+        # Union of batched and unbatched RDDs:
+        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
+        >>> rdd.union(batchedRDD).collect()
+        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
@@ -147,13 +154,8 @@ class RDD(object):
         self.map(f).collect()  # Force evaluation
 
     def collect(self):
-        # To minimize the number of transfers between Python and Java, we'll
-        # flatten each partition into a list before collecting it.  Due to
-        # pipelining, this should add minimal overhead.
-        def asList(iterator):
-            yield list(iterator)
-        picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
-        return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+        picklesInJava = self._jrdd.rdd().collect()
+        return list(self._collect_array_through_file(picklesInJava))
 
     def _collect_array_through_file(self, array):
         # Transferring lots of data through Py4J can be slow because
@@ -214,12 +216,21 @@ class RDD(object):
 
     # TODO: aggregate
 
+    def sum(self):
+        """
+        >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+        6.0
+        """
+        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
     def count(self):
         """
         >>> sc.parallelize([2, 3, 4]).count()
-        3L
+        3
+        >>> sc.parallelize([Batch([2, 3, 4])]).count()
+        3
         """
-        return self._jrdd.count()
+        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
     def countByValue(self):
         """
@@ -342,24 +353,23 @@ class RDD(object):
         """
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
+        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
+        # form the hash buckets in Python, transferring O(numSplits) objects
+        # to Java.  Each object is a (splitNumber, [objects]) pair.
         def add_shuffle_key(iterator):
             buckets = defaultdict(list)
             for (k, v) in iterator:
                 buckets[hashFunc(k) % numSplits].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield str(split)
-                yield dump_pickle(items)
+                yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
-        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
-        # form the hash buckets in Python, transferring O(numSplits) objects
-        # to Java.  Each object is a (splitNumber, [objects]) pair.
         jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
-        # Flatten the resulting RDD:
-        return RDD(jrdd, self.ctx).flatMap(lambda items: items)
+        return RDD(jrdd, self.ctx)
 
     def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                      numSplits=None):
@@ -478,8 +488,19 @@ class PipelinedRDD(RDD):
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        funcs = [self.func, self._bypass_serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+        func = self.func
+        if not self._bypass_serializer and self.ctx.batchSize != 1:
+            oldfunc = self.func
+            batchSize = self.ctx.batchSize
+            if batchSize == -1:  # unlimited batch size
+                def batched_func(iterator):
+                    yield Batch(list(oldfunc(iterator)))
+            else:
+                def batched_func(iterator):
+                    return batched(oldfunc(iterator), batchSize)
+            func = batched_func
+        cmds = [func, self._bypass_serializer]
+        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx.gateway._gateway_client)
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index bfcdda8f12..4ed925697c 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -2,6 +2,33 @@ import struct
 import cPickle
 
 
+class Batch(object):
+    """
+    Used to store multiple RDD entries as a single Java object.
+
+    This relieves us from having to explicitly track whether an RDD
+    is stored as batches of objects and avoids problems when processing
+    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+    with another RDD).
+ """ + def __init__(self, items): + self.items = items + + +def batched(iterator, batchSize): + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == batchSize: + yield Batch(items) + items = [] + count = [] + if items: + yield Batch(items) + + def dump_pickle(obj): return cPickle.dumps(obj, 2) @@ -38,6 +65,11 @@ def read_with_length(stream): def read_from_pickle_file(stream): try: while True: - yield load_pickle(read_with_length(stream)) + obj = load_pickle(read_with_length(stream)) + if type(obj) == Batch: # We don't care about inheritance + for item in obj.items: + yield item + else: + yield obj except EOFError: return -- cgit v1.2.3