diff options
author | Josh Rosen <joshrosen@apache.org> | 2013-11-05 17:52:39 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@apache.org> | 2013-11-10 16:45:38 -0800 |
commit | cbb7f04aef2220ece93dea9f3fa98b5db5f270d6 (patch) | |
tree | 5feaed6b6064b81272fcb74b48ee2579e32de4e6 /python/pyspark/context.py | |
parent | 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e (diff) | |
download | spark-cbb7f04aef2220ece93dea9f3fa98b5db5f270d6.tar.gz spark-cbb7f04aef2220ece93dea9f3fa98b5db5f270d6.tar.bz2 spark-cbb7f04aef2220ece93dea9f3fa98b5db5f270d6.zip |
Add custom serializer support to PySpark.
For now, this only adds MarshalSerializer, but it lays the groundwork
for supporting other custom serializers. Many of these mechanisms
can also be used to support deserialization of different data formats
sent by Java, such as data encoded by MsgPack.
This also fixes a bug in SparkContext.union().
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r-- | python/pyspark/context.py | 61 |
1 files changed, 45 insertions, 16 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0fec1a6bf6..6bb1c6c3a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. 
>>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. 
""" minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. 
""" - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer._dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): |