author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-26 17:34:24 -0800
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-12-26 18:16:09 -0800
commit     e2dad15621f5dc15275b300df05483afde5025a0 (patch)
tree       8038596ff15987dce3be7d7e48e4f9adf49c2220 /pyspark
parent     4608902fb87af64a15b97ab21fe6382cd6e5a644 (diff)
download   spark-e2dad15621f5dc15275b300df05483afde5025a0.tar.gz
           spark-e2dad15621f5dc15275b300df05483afde5025a0.tar.bz2
           spark-e2dad15621f5dc15275b300df05483afde5025a0.zip
Add support for batched serialization of Python objects in PySpark.
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py      |  3
-rw-r--r--  pyspark/pyspark/rdd.py          | 57
-rw-r--r--  pyspark/pyspark/serializers.py  | 34
3 files changed, 74 insertions, 20 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 19f9f9e133..032619693a 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -17,13 +17,14 @@ class SparkContext(object):
readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
- def __init__(self, master, name, defaultParallelism=None):
+ def __init__(self, master, name, defaultParallelism=None, batchSize=-1):
self.master = master
self.name = name
self._jsc = self.jvm.JavaSparkContext(master, name)
self.defaultParallelism = \
defaultParallelism or self._jsc.sc().defaultParallelism()
self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+ self.batchSize = batchSize # -1 represents an unlimited batch size
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
# been pickled, so it can determine which Java broadcast objects to
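
For context, a minimal sketch of how the new batchSize argument might be used from application code (the import path, "local" master, and app name here are illustrative assumptions, not part of this patch):

    from pyspark.context import SparkContext

    # batchSize=-1 (the default) stores each output partition as a single Batch;
    # a positive value pickles objects in groups of that size; batchSize=1
    # effectively disables batching.
    sc = SparkContext("local", "batching demo", batchSize=10)
    assert sc.parallelize(range(100)).map(lambda x: x + 1).count() == 100  # batching is invisible to user code
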
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 01908cff96..d7081dffd2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,6 +2,7 @@ import atexit
from base64 import standard_b64encode as b64enc
from collections import defaultdict
from itertools import chain, ifilter, imap
+import operator
import os
import shlex
from subprocess import Popen, PIPE
@@ -9,7 +10,8 @@ from tempfile import NamedTemporaryFile
from threading import Thread
from pyspark import cloudpickle
-from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+ read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
@@ -83,6 +85,11 @@ class RDD(object):
>>> rdd = sc.parallelize([1, 1, 2, 3])
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
+
+ # Union of batched and unbatched RDDs:
+ >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
+ >>> rdd.union(batchedRDD).collect()
+ [1, 1, 2, 3, 1, 2, 3, 4, 5]
"""
return RDD(self._jrdd.union(other._jrdd), self.ctx)
@@ -147,13 +154,8 @@ class RDD(object):
self.map(f).collect() # Force evaluation
def collect(self):
- # To minimize the number of transfers between Python and Java, we'll
- # flatten each partition into a list before collecting it. Due to
- # pipelining, this should add minimal overhead.
- def asList(iterator):
- yield list(iterator)
- picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
- return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
+ picklesInJava = self._jrdd.rdd().collect()
+ return list(self._collect_array_through_file(picklesInJava))
def _collect_array_through_file(self, array):
# Transferring lots of data through Py4J can be slow because
@@ -214,12 +216,21 @@ class RDD(object):
# TODO: aggregate
+ def sum(self):
+ """
+ >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+ 6.0
+ """
+ return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
def count(self):
"""
>>> sc.parallelize([2, 3, 4]).count()
- 3L
+ 3
+ >>> sc.parallelize([Batch([2, 3, 4])]).count()
+ 3
"""
- return self._jrdd.count()
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
def countByValue(self):
"""
@@ -342,24 +353,23 @@ class RDD(object):
"""
if numSplits is None:
numSplits = self.ctx.defaultParallelism
+ # Transferring O(n) objects to Java is too expensive. Instead, we'll
+ # form the hash buckets in Python, transferring O(numSplits) objects
+ # to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
buckets[hashFunc(k) % numSplits].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
- yield dump_pickle(items)
+ yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
- # Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numSplits) objects
- # to Java. Each object is a (splitNumber, [objects]) pair.
jrdd = pairRDD.partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
- # Flatten the resulting RDD:
- return RDD(jrdd, self.ctx).flatMap(lambda items: items)
+ return RDD(jrdd, self.ctx)
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
numSplits=None):
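
The add_shuffle_key generator above emits an alternating stream of split numbers and pickled batches, which the Java PairwiseRDD pairs back up. A rough standalone imitation of that bucketing, using only names introduced in this patch (the bucket helper below is hypothetical, not part of the change):

    from collections import defaultdict
    from pyspark.serializers import Batch, dump_pickle

    def bucket(pairs, numSplits, hashFunc=hash):
        # Group (key, value) pairs by hash bucket, then emit each bucket as two
        # consecutive stream elements: the split number and one pickled Batch.
        buckets = defaultdict(list)
        for (k, v) in pairs:
            buckets[hashFunc(k) % numSplits].append((k, v))
        for (split, items) in buckets.iteritems():
            yield str(split)
            yield dump_pickle(Batch(items))

    stream = list(bucket([("a", 1), ("b", 2), ("a", 3)], numSplits=2))
    # stream alternates between a split-number string and a pickled Batch of the
    # pairs hashed to that split; the exact assignment depends on hash values.
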
@@ -478,8 +488,19 @@ class PipelinedRDD(RDD):
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- funcs = [self.func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+ func = self.func
+ if not self._bypass_serializer and self.ctx.batchSize != 1:
+ oldfunc = self.func
+ batchSize = self.ctx.batchSize
+ if batchSize == -1: # unlimited batch size
+ def batched_func(iterator):
+ yield Batch(list(oldfunc(iterator)))
+ else:
+ def batched_func(iterator):
+ return batched(oldfunc(iterator), batchSize)
+ func = batched_func
+ cmds = [func, self._bypass_serializer]
+ pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx.gateway._gateway_client)
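
To make the wrapping above concrete, here is a hedged sketch of what happens to a per-partition function once batching kicks in (oldfunc and batchSize are stand-ins for the attributes used in _jrdd):

    from pyspark.serializers import batched

    def oldfunc(iterator):              # a user-level per-partition function
        return (x * 2 for x in iterator)

    batchSize = 2
    def batched_func(iterator):
        return batched(oldfunc(iterator), batchSize)

    out = list(batched_func(iter([1, 2, 3])))
    # out holds two Batch objects whose items are [2, 4] and [6]; the worker
    # pickles each Batch as a single object, and read_from_pickle_file later
    # flattens them back into individual elements.
    assert [b.items for b in out] == [[2, 4], [6]]
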
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index bfcdda8f12..4ed925697c 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -2,6 +2,33 @@ import struct
import cPickle
+class Batch(object):
+ """
+ Used to store multiple RDD entries as a single Java object.
+
+ This relieves us from having to explicitly track whether an RDD
+ is stored as batches of objects and avoids problems when processing
+ the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+ with another RDD).
+ """
+ def __init__(self, items):
+ self.items = items
+
+
+def batched(iterator, batchSize):
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == batchSize:
+ yield Batch(items)
+ items = []
+ count = 0
+ if items:
+ yield Batch(items)
+
+
def dump_pickle(obj):
return cPickle.dumps(obj, 2)
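
A quick sanity check of the batched() helper above (plain Python, no SparkContext required), assuming the counter reset shown in the hunk:

    from pyspark.serializers import batched

    groups = list(batched(iter(range(7)), 3))
    assert [b.items for b in groups] == [[0, 1, 2], [3, 4, 5], [6]]
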
@@ -38,6 +65,11 @@ def read_with_length(stream):
def read_from_pickle_file(stream):
try:
while True:
- yield load_pickle(read_with_length(stream))
+ obj = load_pickle(read_with_length(stream))
+ if type(obj) == Batch: # We don't care about inheritance
+ for item in obj.items:
+ yield item
+ else:
+ yield obj
except EOFError:
return
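
Finally, a hedged round-trip sketch of the flattening performed by read_from_pickle_file(). The local write_with_length helper encodes an assumption about the framing read_with_length() expects (a 4-byte big-endian length prefix followed by the pickled bytes) and is not taken from this patch:

    import struct
    from cStringIO import StringIO

    from pyspark.serializers import Batch, dump_pickle, read_from_pickle_file

    def write_with_length(obj, stream):
        # Assumed framing: 4-byte big-endian length, then the pickled bytes.
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    buf = StringIO()
    write_with_length(dump_pickle(Batch([1, 2, 3])), buf)  # a batched entry
    write_with_length(dump_pickle(4), buf)                 # an unbatched entry
    buf.seek(0)

    # Batch contents are yielded element by element; plain objects pass through.
    assert list(read_from_pickle_file(buf)) == [1, 2, 3, 4]
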