author     Prashant Sharma <prashant.s@imaginea.com>   2013-11-27 14:44:12 +0530
committer  Prashant Sharma <prashant.s@imaginea.com>   2013-11-27 14:44:12 +0530
commit     17987778daac140027b7a01c0ec22f0b3e4f3b83 (patch)
tree       89af24131291a60ac1f4f00cabe27e8119c65593 /python/pyspark/rdd.py
parent     54862af5ee813030ead80ec097f48620ddb974fc (diff)
parent     fb6875dd5c9334802580155464cef9ac4d4cc1f0 (diff)
Merge branch 'master' into wip-scala-2.10
Conflicts:
    core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
    core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
    core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
    core/src/main/scala/org/apache/spark/rdd/RDD.scala
    python/pyspark/rdd.py
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 97
1 file changed, 54 insertions(+), 43 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 245a132dfd..d2cb5f191a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
import operator
import os
import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
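Note: the new imports swap the old dump_pickle/read_from_pickle_file helpers for serializer objects that know how to frame a stream of Python objects as bytes and read it back. Below is a minimal sketch of that kind of interface, purely for orientation; SketchSerializer is invented here and is not the pyspark.serializers implementation.

    import struct
    import cPickle as pickle  # Python 2, matching this codebase

    class SketchSerializer(object):
        """Toy length-prefixed pickle serializer (invented for illustration)."""

        def dumps(self, obj):
            return pickle.dumps(obj, 2)

        def dump_stream(self, iterator, stream):
            # Frame each object with a 4-byte big-endian length prefix.
            for obj in iterator:
                data = self.dumps(obj)
                stream.write(struct.pack("!i", len(data)))
                stream.write(data)

        def load_stream(self, stream):
            # Read frames back until the stream is exhausted.
            while True:
                header = stream.read(4)
                if len(header) < 4:
                    return
                length, = struct.unpack("!i", header)
                yield pickle.loads(stream.read(length))

Roughly speaking, BatchedSerializer wraps a serializer like this so that a whole list of items becomes a single frame, and CartesianDeserializer pairs up two such streams.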
@@ -48,12 +47,12 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx):
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
- self._partitionFunc = None
+ self._jrdd_deserializer = jrdd_deserializer
@property
def context(self):
@@ -247,7 +246,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return RDD(self._jrdd.union(other._jrdd), self.ctx)
+ if self._jrdd_deserializer == other._jrdd_deserializer:
+ rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+ self._jrdd_deserializer)
+ return rdd
+ else:
+ # These RDDs contain data in different serialized formats, so we
+ # must normalize them to the default serializer.
+ self_copy = self._reserialize()
+ other_copy = other._reserialize()
+ return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+
+ def _reserialize(self):
+ if self._jrdd_deserializer == self.ctx.serializer:
+ return self
+ else:
+ return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@@ -334,17 +349,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
- def unpack_batches(pair):
- (x, y) = pair
- if type(x) == Batch or type(y) == Batch:
- xs = x.items if type(x) == Batch else [x]
- ys = y.items if type(y) == Batch else [y]
- for pair in product(xs, ys):
- yield pair
- else:
- yield pair
- return java_cartesian.flatMap(unpack_batches)
+ deserializer = CartesianDeserializer(self._jrdd_deserializer,
+ other._jrdd_deserializer)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@@ -391,8 +398,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -400,10 +407,10 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
- for item in read_from_pickle_file(tempFile):
+ for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
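Note: the collect path still spools bytes through a temporary file to sidestep slow Py4J transfers; the change is simply that the RDD's own deserializer now decodes that file. A self-contained sketch of the same round trip using the toy serializer above (illustrative only; in the real code the JVM writes the file via _writeToFile rather than Python):

    import os
    from tempfile import NamedTemporaryFile

    def sketch_collect_through_file(items, serializer):
        """Round-trip items through a temp file with one serializer,
        mimicking _collect_iterator_through_file (invented for illustration)."""
        tempFile = NamedTemporaryFile(delete=False)
        try:
            serializer.dump_stream(iter(items), tempFile)
            tempFile.close()
            with open(tempFile.name, 'rb') as f:
                return list(serializer.load_stream(f))
        finally:
            os.unlink(tempFile.name)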
@@ -571,7 +578,7 @@ class RDD(object):
items = []
for partition in range(mapped._jrdd.splits().size()):
iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
- items.extend(self._collect_iterator_through_file(iterator))
+ items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@@ -735,6 +742,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
+ outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@@ -743,14 +751,14 @@ class RDD(object):
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield pack_long(split)
- yield dump_pickle(Batch(items))
+ yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
- rdd = RDD(jrdd, self.ctx)
+ rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
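Note: add_shuffle_key builds the shuffle buckets in Python and hands the JVM an alternating stream of packed split ids and serialized buckets, which is why the resulting RDD is tagged with BatchedSerializer(outputSerializer). A standalone sketch of that framing (the function below is invented for illustration, and struct.pack("!q", ...) is assumed to match what pack_long produces):

    import struct
    from collections import defaultdict

    def sketch_add_shuffle_key(iterator, numPartitions, serializer,
                               partitionFunc=hash):
        """Bucket (k, v) pairs by partitionFunc(k) % numPartitions, then emit
        an alternating stream of packed split ids and serialized buckets
        (invented for illustration)."""
        buckets = defaultdict(list)
        for (k, v) in iterator:
            buckets[partitionFunc(k) % numPartitions].append((k, v))
        for (split, items) in buckets.iteritems():
            yield struct.pack("!q", split)   # stands in for pack_long(split)
            yield serializer.dumps(items)    # one whole bucket as a single blob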
@@ -787,7 +795,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
- for (k, v) in iterator:
+ for x in iterator:
+ (k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
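Note: for context, this loop is the map-side combiner of combineByKey: each partition folds its records into per-key combiners before anything is shuffled. A self-contained version of the same behaviour (invented for illustration):

    def sketch_combine_locally(iterator, createCombiner, mergeValue):
        """Fold (k, v) records into per-key combiners within one partition,
        mirroring combineLocally above (invented for illustration)."""
        combiners = {}
        for x in iterator:
            (k, v) = x
            if k not in combiners:
                combiners[k] = createCombiner(v)
            else:
                combiners[k] = mergeValue(combiners[k], v)
        return combiners.iteritems()

    # e.g. per-key sums within one partition:
    #   dict(sketch_combine_locally([("a", 1), ("b", 2), ("a", 3)],
    #                               lambda v: v, lambda c, v: c + v))
    #   -> {'a': 4, 'b': 2}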
@@ -929,38 +938,39 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+ # This transformation is the first in its stage:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_jrdd_deserializer = prev._jrdd_deserializer
+ else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
- self._prev_jrdd = prev._prev_jrdd
- else:
- self.func = func
- self.preservesPartitioning = preservesPartitioning
- self._prev_jrdd = prev._jrdd
+ self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
+ self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- func = self.func
- if not self._bypass_serializer and self.ctx.batchSize != 1:
- oldfunc = self.func
- batchSize = self.ctx.batchSize
- def batched_func(split, iterator):
- return batched(oldfunc(split, iterator), batchSize)
- func = batched_func
- cmds = [func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ if self._bypass_serializer:
+ serializer = NoOpSerializer()
+ else:
+ serializer = self.ctx.serializer
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer().dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
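Note: the branch above keeps a chain of narrow transformations as a single Python stage by composing the new function with the previous one, and the composed function travels to the JVM together with the input deserializer and output serializer. A minimal sketch of the composition and of what the pickled command roughly contains (sketch_pipeline is invented for illustration):

    def sketch_pipeline(prev_func, func):
        """Compose two per-partition functions the way PipelinedRDD does when
        the previous RDD is still pipelinable (invented for illustration)."""
        def pipeline_func(split, iterator):
            return func(split, prev_func(split, iterator))
        return pipeline_func

    # e.g. a map followed by a filter, fused into one pass over the partition:
    #   f1 = lambda split, it: (x * 2 for x in it)
    #   f2 = lambda split, it: (x for x in it if x > 2)
    #   list(sketch_pipeline(f1, f2)(0, iter([1, 2, 3])))  -> [4, 6]
    #
    # The command shipped to the JVM is then roughly
    #   (pipeline_func, prev_jrdd_deserializer, serializer)
    # pickled with CloudPickleSerializer, as in the hunk above.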
@@ -971,8 +981,9 @@ class PipelinedRDD(RDD):
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_tag)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
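Note: on the worker side the same tuple presumably gets unpickled and applied between the two serializers. The sketch below is only an assumption about what the counterpart in worker.py (not part of this diff) would look like; the function and its signature are invented here:

    def sketch_worker_main(infile, outfile, pickled_command, split):
        """Hypothetical worker loop: decode the command, run the pipelined
        function over the deserialized input, serialize the results back
        (assumption, not part of this diff)."""
        import cPickle as pickle  # cloudpickle output is plain-pickle compatible
        func, deserializer, serializer = pickle.loads(pickled_command)
        out_iterator = func(split, deserializer.load_stream(infile))
        serializer.dump_stream(out_iterator, outfile)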