author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-11-26 20:55:40 -0800
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-11-26 20:55:40 -0800
commit     fb6875dd5c9334802580155464cef9ac4d4cc1f0 (patch)
tree       982a3e535fda57950b6e3b1f26b024f2b025d2d9 /python
parent     330ada1766c1f8a7274b5566fa66b796329d7054 (diff)
parent     1b74a27da026aba7dbe2088ee64974d772feb23d (diff)
Merge pull request #146 from JoshRosen/pyspark-custom-serializers
Custom Serializers for PySpark

This pull request adds support for custom serializers to PySpark. For now, all Python-transformed (or parallelize()d) RDDs are serialized with the same serializer that's specified when creating SparkContext. PySpark includes `PickleSerDe` and `MarshalSerDe` classes for using Python's `pickle` and `marshal` serializers. It's pretty easy to add support for other serializers, although I still need to add instructions on this.

A few notable changes:

- The Scala `PythonRDD` class no longer manipulates Pickled objects; data from `textFile` is written to Python as MUTF-8 strings. The Python code performs the appropriate bookkeeping to track which deserializer should be used when reading an underlying JavaRDD. This mechanism could also be used to support other data exchange formats, such as MsgPack.
- Several magic numbers were refactored into constants.
- Batching is implemented by wrapping / decorating an unbatched SerDe.
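For illustration only, here is a minimal sketch of what a third-party serializer could look like against the FramedSerializer interface this patch introduces. The JSONSerializer class below is hypothetical and not part of this change; it simply shows the two methods a subclass must provide.

    # Hypothetical example (not in this patch): a custom serializer built on the
    # FramedSerializer base class added in python/pyspark/serializers.py.
    import json

    from pyspark.serializers import FramedSerializer

    class JSONSerializer(FramedSerializer):
        """Serializes JSON-friendly objects (dicts, lists, strings, numbers)."""

        def dumps(self, obj):
            # FramedSerializer frames each record with a 32-bit length prefix,
            # so dumps() only needs to return the serialized bytes.
            return json.dumps(obj)

        def loads(self, obj):
            return json.loads(obj)

    # Usage would mirror the built-in serializers: pass an instance when creating
    # SparkContext; SparkContext wraps it in a BatchedSerializer unless batchSize == 1.
    # sc = SparkContext('local', 'test', serializer=JSONSerializer())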
Diffstat (limited to 'python')
-rw-r--r--  python/epydoc.conf                2
-rw-r--r--  python/pyspark/accumulators.py    6
-rw-r--r--  python/pyspark/context.py        71
-rw-r--r--  python/pyspark/rdd.py            97
-rw-r--r--  python/pyspark/serializers.py   301
-rw-r--r--  python/pyspark/tests.py           3
-rw-r--r--  python/pyspark/worker.py         44
-rwxr-xr-x  python/run-tests                  1
8 files changed, 383 insertions, 142 deletions
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36..0b42e729f8 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
private: no
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d96689a..2204e9c9ca 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
+pickleSer = PickleSerializer()
+
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
- (aid, update) = load_pickle(read_with_length(self.rfile))
+ (aid, update) = pickleSer._read_with_length(self.rfile)
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888..cbd41e58c4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -42,7 +42,7 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
+ _writeToFile = None
_takePartition = None
_next_accum_id = 0
_active_spark_context = None
@@ -51,7 +51,7 @@ class SparkContext(object):
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024):
+ environment=None, batchSize=1024, serializer=PickleSerializer()):
"""
Create a new SparkContext.
@@ -67,6 +67,7 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+ @param serializer: The serializer for RDDs.
>>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
self.environment = environment or {}
- self.batchSize = batchSize # -1 represents a unlimited batch size
+ self._batchSize = batchSize # -1 represents an unlimited batch size
+ self._unbatched_serializer = serializer
+ if batchSize == 1:
+ self.serializer = self._unbatched_serializer
+ else:
+ self.serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
# Create the Java SparkContext through Py4J
empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -125,8 +132,8 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._writeToFile = \
+ SparkContext._jvm.PythonRDD.writeToFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
@@ -184,15 +191,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self.batchSize)
+ batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
- c = batched(c, batchSize)
- for x in c:
- write_with_length(dump_pickle(x), tempFile)
+ serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+ else:
+ serializer = self._unbatched_serializer
+ serializer.dump_stream(c, tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
- return RDD(jrdd, self)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@@ -201,21 +210,39 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jsc.textFile(name, minSplits)
- return RDD(jrdd, self)
+ return RDD(self._jsc.textFile(name, minSplits), self,
+ MUTF8Deserializer())
- def _checkpointFile(self, name):
+ def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
- return RDD(jrdd, self)
+ return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
+
+ This supports unions() of RDDs with different serialized formats,
+ although this forces them to be reserialized using the default
+ serializer:
+
+ >>> path = os.path.join(tempdir, "union-text.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("Hello")
+ >>> textFile = sc.textFile(path)
+ >>> textFile.collect()
+ [u'Hello']
+ >>> parallelized = sc.parallelize(["World!"])
+ >>> sorted(sc.union([textFile, parallelized]).collect())
+ [u'Hello', 'World!']
"""
+ first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+ if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+ rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self.gateway._gateway_client)
- return RDD(self._jsc.union(first, rest), self)
+ rest = ListConverter().convert(rest, self._gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self,
+ rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
@@ -223,7 +250,9 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
- jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ pickleSer = PickleSerializer()
+ pickled = pickleSer.dumps(value)
+ jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@@ -235,7 +264,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
- if accum_param == None:
+ if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8bee..957f3f89c0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
import operator
import os
import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -48,12 +47,12 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx):
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
- self._partitionFunc = None
+ self._jrdd_deserializer = jrdd_deserializer
@property
def context(self):
@@ -247,7 +246,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return RDD(self._jrdd.union(other._jrdd), self.ctx)
+ if self._jrdd_deserializer == other._jrdd_deserializer:
+ rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+ self._jrdd_deserializer)
+ return rdd
+ else:
+ # These RDDs contain data in different serialized formats, so we
+ # must normalize them to the default serializer.
+ self_copy = self._reserialize()
+ other_copy = other._reserialize()
+ return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+
+ def _reserialize(self):
+ if self._jrdd_deserializer == self.ctx.serializer:
+ return self
+ else:
+ return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@@ -334,17 +349,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
- def unpack_batches(pair):
- (x, y) = pair
- if type(x) == Batch or type(y) == Batch:
- xs = x.items if type(x) == Batch else [x]
- ys = y.items if type(y) == Batch else [y]
- for pair in product(xs, ys):
- yield pair
- else:
- yield pair
- return java_cartesian.flatMap(unpack_batches)
+ deserializer = CartesianDeserializer(self._jrdd_deserializer,
+ other._jrdd_deserializer)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@@ -391,8 +398,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -400,10 +407,10 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
- for item in read_from_pickle_file(tempFile):
+ for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
@@ -571,7 +578,7 @@ class RDD(object):
items = []
for partition in range(mapped._jrdd.splits().size()):
iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
- items.extend(self._collect_iterator_through_file(iterator))
+ items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@@ -735,6 +742,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
+ outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@@ -743,14 +751,14 @@ class RDD(object):
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield pack_long(split)
- yield dump_pickle(Batch(items))
+ yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
- rdd = RDD(jrdd, self.ctx)
+ rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
@@ -787,7 +795,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
- for (k, v) in iterator:
+ for x in iterator:
+ (k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
@@ -929,38 +938,39 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+ # This transformation is the first in its stage:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_jrdd_deserializer = prev._jrdd_deserializer
+ else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
- self._prev_jrdd = prev._prev_jrdd
- else:
- self.func = func
- self.preservesPartitioning = preservesPartitioning
- self._prev_jrdd = prev._jrdd
+ self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
+ self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- func = self.func
- if not self._bypass_serializer and self.ctx.batchSize != 1:
- oldfunc = self.func
- batchSize = self.ctx.batchSize
- def batched_func(split, iterator):
- return batched(oldfunc(split, iterator), batchSize)
- func = batched_func
- cmds = [func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ if self._bypass_serializer:
+ serializer = NoOpSerializer()
+ else:
+ serializer = self.ctx.serializer
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer().dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
@@ -971,8 +981,9 @@ class PipelinedRDD(RDD):
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c7..811fa6f018 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
# limitations under the License.
#
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+ END_OF_DATA_SECTION = -1
+ PYTHON_EXCEPTION_THROWN = -2
+ TIMING_DATA = -3
+
+
+class Serializer(object):
+
+ def dump_stream(self, iterator, stream):
+ """
+ Serialize an iterator of objects to the output stream.
+ """
+ raise NotImplementedError
+
+ def load_stream(self, stream):
+ """
+ Return an iterator of deserialized objects from the input stream.
+ """
+ raise NotImplementedError
+
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.load_stream(stream)
+
+ # Note: our notion of "equality" is that output generated by
+ # equal serializers can be deserialized using the same serializer.
+
+ # This default implementation handles the simple cases;
+ # subclasses should override __eq__ as appropriate.
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+ """
+ Serializer that writes objects as a stream of (length, data) pairs,
+ where C{length} is a 32-bit integer and data is C{length} bytes.
+ """
+
+ def dump_stream(self, iterator, stream):
+ for obj in iterator:
+ self._write_with_length(obj, stream)
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self._read_with_length(stream)
+ except EOFError:
+ return
+
+ def _write_with_length(self, obj, stream):
+ serialized = self.dumps(obj)
+ write_int(len(serialized), stream)
+ stream.write(serialized)
+
+ def _read_with_length(self, stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return self.loads(obj)
+
+ def dumps(self, obj):
+ """
+ Serialize an object into a byte array.
+ When batching is used, this will be called with an array of objects.
+ """
+ raise NotImplementedError
+
+ def loads(self, obj):
+ """
+ Deserialize an object from a byte array.
+ """
+ raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+ """
+ Serializes a stream of objects in batches by calling its wrapped
+ Serializer with streams of objects.
+ """
+
+ UNLIMITED_BATCH_SIZE = -1
+
+ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+ self.serializer = serializer
+ self.batchSize = batchSize
+
+ def _batched(self, iterator):
+ if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+ yield list(iterator)
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == self.batchSize:
+ yield items
+ items = []
+ count = 0
+ if items:
+ yield items
+
+ def dump_stream(self, iterator, stream):
+ self.serializer.dump_stream(self._batched(iterator), stream)
+
+ def load_stream(self, stream):
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __eq__(self, other):
+ return isinstance(other, BatchedSerializer) and \
+ other.serializer == self.serializer
+
+ def __str__(self):
+ return "BatchedSerializer<%s>" % str(self.serializer)
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
"""
- Used to store multiple RDD entries as a single Java object.
+ Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ """
+
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def load_stream(self, stream):
+ key_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_stream = self.val_ser._load_stream_without_unbatching(stream)
+ key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+ val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+ for (keys, vals) in izip(key_stream, val_stream):
+ keys = keys if key_is_batched else [keys]
+ vals = vals if val_is_batched else [vals]
+ for pair in product(keys, vals):
+ yield pair
+
+ def __eq__(self, other):
+ return isinstance(other, CartesianDeserializer) and \
+ self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+ def __str__(self):
+ return "CartesianDeserializer<%s, %s>" % \
+ (str(self.key_ser), str(self.val_ser))
- This relieves us from having to explicitly track whether an RDD
- is stored as batches of objects and avoids problems when processing
- the union() of batched and unbatched RDDs (e.g. the union() of textFile()
- with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+ def loads(self, obj): return obj
+ def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
"""
- def __init__(self, items):
- self.items = items
+ Serializes objects using Python's cPickle serializer:
+ http://docs.python.org/2/library/pickle.html
-def batched(iterator, batchSize):
- if batchSize == -1: # unlimited batch size
- yield Batch(list(iterator))
- else:
- items = []
- count = 0
- for item in iterator:
- items.append(item)
- count += 1
- if count == batchSize:
- yield Batch(items)
- items = []
- count = 0
- if items:
- yield Batch(items)
+ This serializer supports nearly any Python object, but may
+ not be as fast as more specialized serializers.
+ """
+ def dumps(self, obj): return cPickle.dumps(obj, 2)
+ loads = cPickle.loads
-def dump_pickle(obj):
- return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
+ def dumps(self, obj): return cloudpickle.dumps(obj, 2)
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+ """
+ Serializes objects using Python's Marshal serializer:
+
+ http://docs.python.org/2/library/marshal.html
+
+ This serializer is faster than PickleSerializer but supports fewer datatypes.
+ """
+
+ dumps = marshal.dumps
+ loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+ """
+ Deserializes streams written by Java's DataOutputStream.writeUTF().
+ """
+
+ def loads(self, stream):
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def read_long(stream):
@@ -84,25 +308,4 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
- stream.write(obj)
-
-
-def read_with_length(stream):
- length = read_int(stream)
- obj = stream.read(length)
- if obj == "":
- raise EOFError
- return obj
-
-
-def read_from_pickle_file(stream):
- try:
- while True:
- obj = load_pickle(read_with_length(stream))
- if type(obj) == Batch: # We don't care about inheritance
- for item in obj.items:
- yield item
- else:
- yield obj
- except EOFError:
- return
+ stream.write(obj)
\ No newline at end of file
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6..621e1cb58c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
- recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+ flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef7..f2b3f3c142 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
import time
import socket
import traceback
-from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+ write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
-def load_obj(infile):
- return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
def report_times(outfile, boot, init, finish):
- write_int(-3, outfile)
+ write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = load_pickle(read_with_length(infile))
+ spark_files_dir = mutf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -60,38 +59,33 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+ filename = mutf8_deserializer.loads(infile)
+ sys.path.append(os.path.join(spark_files_dir, filename))
- # now load function
- func = load_obj(infile)
- bypassSerializer = load_obj(infile)
- if bypassSerializer:
- dumps = lambda x: x
- else:
- dumps = dump_pickle
+ command = pickleSer._read_with_length(infile)
+ (func, deserializer, serializer) = command
init_time = time.time()
- iterator = read_from_pickle_file(infile)
try:
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), outfile)
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
- write_int(-2, outfile)
+ write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, outfile)
- for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), outfile)
- write_int(-1, outfile)
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ write_int(len(_accumulatorRegistry), outfile)
+ for (aid, accum) in _accumulatorRegistry.items():
+ pickleSer._write_with_length((aid, accum._value), outfile)
if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9d..d4dad672d2 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then