diff options
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r-- | python/pyspark/rdd.py | 86 |
1 files changed, 47 insertions, 39 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d3c4d13a1e..6691c30519 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,13 +48,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None - self._stage_input_is_pairs = False + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -248,7 +247,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. + self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -335,18 +350,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - java_cartesian._stage_input_is_pairs = True - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -405,7 +411,7 @@ class RDD(object): self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -573,7 +579,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -737,6 +743,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. + outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -745,14 +752,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer._dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -789,7 +796,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -931,38 +939,38 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self._stage_input_is_pairs = prev._stage_input_is_pairs + self._prev_jrdd = prev._prev_jrdd # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + cmds = [self.func, self._prev_jrdd_deserializer, serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], |