path: root/pyspark
author     Josh Rosen <rosenville@gmail.com>          2012-08-18 16:07:10 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>   2012-08-21 14:01:27 -0700
commit     fd94e5443c99775bfad1928729f5075c900ad0f9 (patch)
tree       1bebffa4c656266bc35bc182e8e6569cc34c5079 /pyspark
parent     13b9514966a423f80f672f23f42ec3f0113936fd (diff)
download   spark-fd94e5443c99775bfad1928729f5075c900ad0f9.tar.gz
           spark-fd94e5443c99775bfad1928729f5075c900ad0f9.tar.bz2
           spark-fd94e5443c99775bfad1928729f5075c900ad0f9.zip
Use only cPickle for serialization in Python API.
Objects serialized with JSON can be compared for equality, but JSON can be slow to serialize and only supports a limited range of data types.
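For context, the tradeoff in doctest form (a sketch, not part of the patch): the old JSONSerializer used simplejson with sort_keys=True, so equal objects always serialized to equal strings, while cPickle's output is not canonical but round-trips far more Python types.

>>> import cPickle, simplejson
>>> a = {1: 0, 9: 0}
>>> b = {9: 0, 1: 0}
>>> simplejson.dumps(a, sort_keys=True) == simplejson.dumps(b, sort_keys=True)
True
>>> cPickle.dumps(a) == cPickle.dumps(b)       # pickles of equal objects need not match
False
>>> cPickle.loads(cPickle.dumps((1, (2, 3))))  # tuples survive; JSON would decode them as lists
(1, (2, 3))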
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py      |  49
-rw-r--r--  pyspark/pyspark/java_gateway.py |   1
-rw-r--r--  pyspark/pyspark/join.py         |  32
-rw-r--r--  pyspark/pyspark/rdd.py          | 414
-rw-r--r--  pyspark/pyspark/serializers.py  | 233
-rw-r--r--  pyspark/pyspark/worker.py       |  64
6 files changed, 233 insertions, 560 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 587ab12b5f..ac7e4057e9 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -3,22 +3,24 @@ import atexit
from tempfile import NamedTemporaryFile
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import JSONSerializer, NopSerializer
-from pyspark.rdd import RDD, PairRDD
+from pyspark.serializers import PickleSerializer, dumps
+from pyspark.rdd import RDD
class SparkContext(object):
gateway = launch_gateway()
jvm = gateway.jvm
- python_dump = jvm.spark.api.python.PythonRDD.pythonDump
+ pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
+ asPickle = jvm.spark.api.python.PythonRDD.asPickle
+ arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
- def __init__(self, master, name, defaultSerializer=JSONSerializer,
- defaultParallelism=None, pythonExec='python'):
+
+ def __init__(self, master, name, defaultParallelism=None,
+ pythonExec='python'):
self.master = master
self.name = name
self._jsc = self.jvm.JavaSparkContext(master, name)
- self.defaultSerializer = defaultSerializer
self.defaultParallelism = \
defaultParallelism or self._jsc.sc().defaultParallelism()
self.pythonExec = pythonExec
@@ -31,39 +33,26 @@ class SparkContext(object):
self._jsc.stop()
self._jsc = None
- def parallelize(self, c, numSlices=None, serializer=None):
- serializer = serializer or self.defaultSerializer
- numSlices = numSlices or self.defaultParallelism
- # Calling the Java parallelize() method with an ArrayList is too slow,
- # because it sends O(n) Py4J commands. As an alternative, serialized
- # objects are written to a file and loaded through textFile().
- tempFile = NamedTemporaryFile(delete=False)
- tempFile.writelines(serializer.dumps(x) + '\n' for x in c)
- tempFile.close()
- atexit.register(lambda: os.unlink(tempFile.name))
- return self.textFile(tempFile.name, numSlices, serializer)
-
- def parallelizePairs(self, c, numSlices=None, keySerializer=None,
- valSerializer=None):
+ def parallelize(self, c, numSlices=None):
"""
>>> sc = SparkContext("local", "test")
- >>> rdd = sc.parallelizePairs([(1, 2), (3, 4)])
+ >>> rdd = sc.parallelize([(1, 2), (3, 4)])
>>> rdd.collect()
[(1, 2), (3, 4)]
"""
- keySerializer = keySerializer or self.defaultSerializer
- valSerializer = valSerializer or self.defaultSerializer
numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
- for (k, v) in c:
- tempFile.write(keySerializer.dumps(k).rstrip('\r\n') + '\n')
- tempFile.write(valSerializer.dumps(v).rstrip('\r\n') + '\n')
+ for x in c:
+ dumps(PickleSerializer.dumps(x), tempFile)
tempFile.close()
atexit.register(lambda: os.unlink(tempFile.name))
- jrdd = self.textFile(tempFile.name, numSlices)._pipePairs([], "echo")
- return PairRDD(jrdd, self, keySerializer, valSerializer)
+ jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self)
- def textFile(self, name, numSlices=None, serializer=NopSerializer):
+ def textFile(self, name, numSlices=None):
numSlices = numSlices or self.defaultParallelism
jrdd = self._jsc.textFile(name, numSlices)
- return RDD(jrdd, self, serializer)
+ return RDD(jrdd, self)
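A rough sketch of the new parallelize() write path, using names from the diff: each element is pickled and written to a temp file as a 4-byte big-endian length followed by the pickle bytes, which spark.api.python.PythonRDD.pickleFile then reads on the JVM side.

>>> import cPickle, struct
>>> from tempfile import NamedTemporaryFile
>>> tempFile = NamedTemporaryFile(delete=False)
>>> for x in [(1, 2), (3, 4)]:
...     pickled = cPickle.dumps(x, -1)                   # PickleSerializer.dumps(x)
...     tempFile.write(struct.pack("!i", len(pickled)))  # serializers.dumps(): length prefix
...     tempFile.write(pickled)                          # ...then the pickle itself
>>> tempFile.close()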
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index 2df80aee85..bcb405ba72 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -16,5 +16,4 @@ def launch_gateway():
java_import(gateway.jvm, "spark.api.java.*")
java_import(gateway.jvm, "spark.api.python.*")
java_import(gateway.jvm, "scala.Tuple2")
- java_import(gateway.jvm, "spark.api.python.PythonRDD.pythonDump")
return gateway
diff --git a/pyspark/pyspark/join.py b/pyspark/pyspark/join.py
index c67520fce8..7036c47980 100644
--- a/pyspark/pyspark/join.py
+++ b/pyspark/pyspark/join.py
@@ -30,15 +30,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
-from pyspark.serializers import PairSerializer, OptionSerializer, \
- ArraySerializer
-def _do_python_join(rdd, other, numSplits, dispatch, valSerializer):
- vs = rdd.mapPairs(lambda (k, v): (k, (1, v)))
- ws = other.mapPairs(lambda (k, v): (k, (2, v)))
- return vs.union(ws).groupByKey(numSplits) \
- .flatMapValues(dispatch, valSerializer)
+def _do_python_join(rdd, other, numSplits, dispatch):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
def python_join(rdd, other, numSplits):
@@ -50,8 +47,7 @@ def python_join(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return [(v, w) for v in vbuf for w in wbuf]
- valSerializer = PairSerializer(rdd.valSerializer, other.valSerializer)
- return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+ return _do_python_join(rdd, other, numSplits, dispatch)
def python_right_outer_join(rdd, other, numSplits):
@@ -65,9 +61,7 @@ def python_right_outer_join(rdd, other, numSplits):
if not vbuf:
vbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- valSerializer = PairSerializer(OptionSerializer(rdd.valSerializer),
- other.valSerializer)
- return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+ return _do_python_join(rdd, other, numSplits, dispatch)
def python_left_outer_join(rdd, other, numSplits):
@@ -81,17 +75,12 @@ def python_left_outer_join(rdd, other, numSplits):
if not wbuf:
wbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- valSerializer = PairSerializer(rdd.valSerializer,
- OptionSerializer(other.valSerializer))
- return _do_python_join(rdd, other, numSplits, dispatch, valSerializer)
+ return _do_python_join(rdd, other, numSplits, dispatch)
def python_cogroup(rdd, other, numSplits):
- resultValSerializer = PairSerializer(
- ArraySerializer(rdd.valSerializer),
- ArraySerializer(other.valSerializer))
- vs = rdd.mapPairs(lambda (k, v): (k, (1, v)))
- ws = other.mapPairs(lambda (k, v): (k, (2, v)))
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -100,5 +89,4 @@ def python_cogroup(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return (vbuf, wbuf)
- return vs.union(ws).groupByKey(numSplits) \
- .mapValues(dispatch, resultValSerializer)
+ return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
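For reference, the tag-and-dispatch idea behind these pure-Python joins in isolation: left-hand values are tagged 1, right-hand values 2, groupByKey() collects both per key, and dispatch() recombines them. Running python_join's dispatch on one key's group:

>>> def dispatch(seq):
...     vbuf, wbuf = [], []
...     for (n, v) in seq:
...         if n == 1:
...             vbuf.append(v)
...         elif n == 2:
...             wbuf.append(v)
...     return [(v, w) for v in vbuf for w in wbuf]
>>> dispatch([(1, 1), (2, 2), (2, 3)])   # one left value, two right values for the same key
[(1, 2), (1, 3)]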
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 5579c56de3..8eccddc0a2 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,31 +1,17 @@
from base64 import standard_b64encode as b64enc
-from pyspark import cloudpickle
-from itertools import chain
-from pyspark.serializers import PairSerializer, NopSerializer, \
- OptionSerializer, ArraySerializer
+from pyspark import cloudpickle
+from pyspark.serializers import PickleSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
class RDD(object):
- def __init__(self, jrdd, ctx, serializer=None):
+ def __init__(self, jrdd, ctx):
self._jrdd = jrdd
self.is_cached = False
self.ctx = ctx
- self.serializer = serializer or ctx.defaultSerializer
-
- def _builder(self, jrdd, ctx):
- return RDD(jrdd, ctx, self.serializer)
-
- @property
- def id(self):
- return self._jrdd.id()
-
- @property
- def splits(self):
- return self._jrdd.splits()
@classmethod
def _get_pipe_command(cls, command, functions):
@@ -41,55 +27,18 @@ class RDD(object):
self._jrdd.cache()
return self
- def map(self, f, serializer=None, preservesPartitioning=False):
- return MappedRDD(self, f, serializer, preservesPartitioning)
-
- def mapPairs(self, f, keySerializer=None, valSerializer=None,
- preservesPartitioning=False):
- return PairMappedRDD(self, f, keySerializer, valSerializer,
- preservesPartitioning)
+ def map(self, f, preservesPartitioning=False):
+ return MappedRDD(self, f, preservesPartitioning)
- def flatMap(self, f, serializer=None):
+ def flatMap(self, f):
"""
>>> rdd = sc.parallelize([2, 3, 4])
>>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
[1, 1, 1, 2, 2, 3]
- """
- serializer = serializer or self.ctx.defaultSerializer
- dumps = serializer.dumps
- loads = self.serializer.loads
- def func(x):
- pickled_elems = (dumps(y) for y in f(loads(x)))
- return "\n".join(pickled_elems) or None
- pipe_command = RDD._get_pipe_command("map", [func])
- class_manifest = self._jrdd.classManifest()
- jrdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
- False, self.ctx.pythonExec,
- class_manifest).asJavaRDD()
- return RDD(jrdd, self.ctx, serializer)
-
- def flatMapPairs(self, f, keySerializer=None, valSerializer=None,
- preservesPartitioning=False):
- """
- >>> rdd = sc.parallelize([2, 3, 4])
- >>> sorted(rdd.flatMapPairs(lambda x: [(x, x), (x, x)]).collect())
+ >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
[(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
"""
- keySerializer = keySerializer or self.ctx.defaultSerializer
- valSerializer = valSerializer or self.ctx.defaultSerializer
- dumpk = keySerializer.dumps
- dumpv = valSerializer.dumps
- loads = self.serializer.loads
- def func(x):
- pairs = f(loads(x))
- pickled_pairs = ((dumpk(k), dumpv(v)) for (k, v) in pairs)
- return "\n".join(chain.from_iterable(pickled_pairs)) or None
- pipe_command = RDD._get_pipe_command("map", [func])
- class_manifest = self._jrdd.classManifest()
- python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command,
- preservesPartitioning, self.ctx.pythonExec, class_manifest)
- return PairRDD(python_rdd.asJavaPairRDD(), self.ctx, keySerializer,
- valSerializer)
+ return MappedRDD(self, f, preservesPartitioning=False, command='flatmap')
def filter(self, f):
"""
@@ -97,9 +46,8 @@ class RDD(object):
>>> rdd.filter(lambda x: x % 2 == 0).collect()
[2, 4]
"""
- loads = self.serializer.loads
- def filter_func(x): return x if f(loads(x)) else None
- return self._builder(self._pipe(filter_func), self.ctx)
+ def filter_func(x): return x if f(x) else None
+ return RDD(self._pipe(filter_func), self.ctx)
def _pipe(self, functions, command="map"):
class_manifest = self._jrdd.classManifest()
@@ -108,32 +56,22 @@ class RDD(object):
False, self.ctx.pythonExec, class_manifest)
return python_rdd.asJavaRDD()
- def _pipePairs(self, functions, command="mapPairs",
- preservesPartitioning=False):
- class_manifest = self._jrdd.classManifest()
- pipe_command = RDD._get_pipe_command(command, functions)
- python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(), pipe_command,
- preservesPartitioning, self.ctx.pythonExec, class_manifest)
- return python_rdd.asJavaPairRDD()
-
def distinct(self):
"""
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
- if self.serializer.is_comparable:
- return self._builder(self._jrdd.distinct(), self.ctx)
- return self.mapPairs(lambda x: (x, "")) \
+ return self.map(lambda x: (x, "")) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
def sample(self, withReplacement, fraction, seed):
jrdd = self._jrdd.sample(withReplacement, fraction, seed)
- return self._builder(jrdd, self.ctx)
+ return RDD(jrdd, self.ctx)
def takeSample(self, withReplacement, num, seed):
vals = self._jrdd.takeSample(withReplacement, num, seed)
- return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+ return [PickleSerializer.loads(x) for x in vals]
def union(self, other):
"""
@@ -141,7 +79,7 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return self._builder(self._jrdd.union(other._jrdd), self.ctx)
+ return RDD(self._jrdd.union(other._jrdd), self.ctx)
# TODO: sort
@@ -155,16 +93,17 @@ class RDD(object):
>>> sorted(rdd.cartesian(rdd).collect())
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
- return PairRDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
# numsplits
def groupBy(self, f, numSplits=None):
"""
>>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
- >>> sorted(rdd.groupBy(lambda x: x % 2).collect())
+ >>> result = rdd.groupBy(lambda x: x % 2).collect()
+ >>> sorted([(x, sorted(y)) for (x, y) in result])
[(0, [2, 8]), (1, [1, 1, 3, 5])]
"""
- return self.mapPairs(lambda x: (f(x), x)).groupByKey(numSplits)
+ return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
# TODO: pipe
@@ -178,25 +117,19 @@ class RDD(object):
self.map(f).collect() # Force evaluation
def collect(self):
- vals = self._jrdd.collect()
- return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+ pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
+ return PickleSerializer.loads(bytes(pickle))
- def reduce(self, f, serializer=None):
+ def reduce(self, f):
"""
- >>> import operator
- >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(operator.add)
+ >>> from operator import add
+ >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
15
+ >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
+ 10
"""
- serializer = serializer or self.ctx.defaultSerializer
- loads = self.serializer.loads
- dumps = serializer.dumps
- def reduceFunction(x, acc):
- if acc is None:
- return loads(x)
- else:
- return f(loads(x), acc)
- vals = self._pipe([reduceFunction, dumps], command="reduce").collect()
- return reduce(f, (serializer.loads(x) for x in vals))
+ vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect()
+ return reduce(f, vals)
# TODO: fold
@@ -216,36 +149,35 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4]).take(2)
[2, 3]
"""
- vals = self._jrdd.take(num)
- return [self.serializer.loads(self.ctx.python_dump(x)) for x in vals]
+ pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
+ return PickleSerializer.loads(bytes(pickle))
def first(self):
"""
>>> sc.parallelize([2, 3, 4]).first()
2
"""
- return self.serializer.loads(self.ctx.python_dump(self._jrdd.first()))
+ return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first())))
# TODO: saveAsTextFile
# TODO: saveAsObjectFile
+ # Pair functions
-class PairRDD(RDD):
-
- def __init__(self, jrdd, ctx, keySerializer=None, valSerializer=None):
- RDD.__init__(self, jrdd, ctx)
- self.keySerializer = keySerializer or ctx.defaultSerializer
- self.valSerializer = valSerializer or ctx.defaultSerializer
- self.serializer = \
- PairSerializer(self.keySerializer, self.valSerializer)
-
- def _builder(self, jrdd, ctx):
- return PairRDD(jrdd, ctx, self.keySerializer, self.valSerializer)
+ def collectAsMap(self):
+ """
+ >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
+ >>> m[1]
+ 2
+ >>> m[3]
+ 4
+ """
+ return dict(self.collect())
def reduceByKey(self, func, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.reduceByKey(lambda a, b: a + b).collect())
[('a', 2), ('b', 1)]
"""
@@ -259,90 +191,67 @@ class PairRDD(RDD):
def join(self, other, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
- >>> y = sc.parallelizePairs([("a", 2), ("a", 3)])
- >>> x.join(y).collect()
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2), ("a", 3)])
+ >>> sorted(x.join(y).collect())
[('a', (1, 2)), ('a', (1, 3))]
-
- Check that we get a PairRDD-like object back:
- >>> assert x.join(y).join
"""
- assert self.keySerializer.name == other.keySerializer.name
- if self.keySerializer.is_comparable:
- return PairRDD(self._jrdd.join(other._jrdd),
- self.ctx, self.keySerializer,
- PairSerializer(self.valSerializer, other.valSerializer))
- else:
- return python_join(self, other, numSplits)
+ return python_join(self, other, numSplits)
def leftOuterJoin(self, other, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
- >>> y = sc.parallelizePairs([("a", 2)])
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
>>> sorted(x.leftOuterJoin(y).collect())
[('a', (1, 2)), ('b', (4, None))]
"""
- assert self.keySerializer.name == other.keySerializer.name
- if self.keySerializer.is_comparable:
- return PairRDD(self._jrdd.leftOuterJoin(other._jrdd),
- self.ctx, self.keySerializer,
- PairSerializer(self.valSerializer,
- OptionSerializer(other.valSerializer)))
- else:
- return python_left_outer_join(self, other, numSplits)
+ return python_left_outer_join(self, other, numSplits)
def rightOuterJoin(self, other, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
- >>> y = sc.parallelizePairs([("a", 2)])
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
>>> sorted(y.rightOuterJoin(x).collect())
[('a', (2, 1)), ('b', (None, 4))]
"""
- assert self.keySerializer.name == other.keySerializer.name
- if self.keySerializer.is_comparable:
- return PairRDD(self._jrdd.rightOuterJoin(other._jrdd),
- self.ctx, self.keySerializer,
- PairSerializer(OptionSerializer(self.valSerializer),
- other.valSerializer))
- else:
- return python_right_outer_join(self, other, numSplits)
+ return python_right_outer_join(self, other, numSplits)
+
+ # TODO: pipelining
+ # TODO: optimizations
+ def shuffle(self, numSplits):
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ pipe_command = RDD._get_pipe_command('shuffle_map_step', [])
+ class_manifest = self._jrdd.classManifest()
+ python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
+ pipe_command, False, self.ctx.pythonExec, class_manifest)
+ partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
+ jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
+ jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
+ # TODO: extract second value.
+ return RDD(jrdd, self.ctx)
+
+
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
- numSplits=None, serializer=None):
+ numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> def f(x): return x
>>> def add(a, b): return a + str(b)
>>> sorted(x.combineByKey(str, add, add).collect())
[('a', '11'), ('b', '1')]
"""
- serializer = serializer or self.ctx.defaultSerializer
if numSplits is None:
numSplits = self.ctx.defaultParallelism
- # Use hash() to create keys that are comparable in Java.
- loadkv = self.serializer.loads
- def pairify(kv):
- # TODO: add method to deserialize only the key or value from
- # a PairSerializer?
- key = loadkv(kv)[0]
- return (str(hash(key)), kv)
- partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
- jrdd = self._pipePairs(pairify).partitionBy(partitioner)
- pairified = PairRDD(jrdd, self.ctx, NopSerializer, self.serializer)
-
- loads = PairSerializer(NopSerializer, self.serializer).loads
- dumpk = self.keySerializer.dumps
- dumpc = serializer.dumps
-
- functions = [createCombiner, mergeValue, mergeCombiners, loads, dumpk,
- dumpc]
- jpairs = pairified._pipePairs(functions, "combine_by_key",
- preservesPartitioning=True)
- return PairRDD(jpairs, self.ctx, self.keySerializer, serializer)
+ shuffled = self.shuffle(numSplits)
+ functions = [createCombiner, mergeValue, mergeCombiners]
+ jpairs = shuffled._pipe(functions, "combine_by_key")
+ return RDD(jpairs, self.ctx)
def groupByKey(self, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 1), ("a", 1)])
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.groupByKey().collect())
[('a', [1, 1]), ('b', [1])]
"""
@@ -360,29 +269,15 @@ class PairRDD(RDD):
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numSplits)
- def collectAsMap(self):
- """
- >>> m = sc.parallelizePairs([(1, 2), (3, 4)]).collectAsMap()
- >>> m[1]
- 2
- >>> m[3]
- 4
- """
- m = self._jrdd.collectAsMap()
- def loads(x):
- (k, v) = x
- return (self.keySerializer.loads(k), self.valSerializer.loads(v))
- return dict(loads(x) for x in m.items())
-
- def flatMapValues(self, f, valSerializer=None):
+ def flatMapValues(self, f):
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
- return self.flatMapPairs(flat_map_fn, self.keySerializer,
- valSerializer, True)
+ return self.flatMap(flat_map_fn)
- def mapValues(self, f, valSerializer=None):
+ def mapValues(self, f):
map_values_fn = lambda (k, v): (k, f(v))
- return self.mapPairs(map_values_fn, self.keySerializer, valSerializer,
- True)
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ # TODO: implement shuffle.
# TODO: support varargs cogroup of several RDDs.
def groupWith(self, other):
@@ -390,20 +285,12 @@ class PairRDD(RDD):
def cogroup(self, other, numSplits=None):
"""
- >>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
- >>> y = sc.parallelizePairs([("a", 2)])
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
>>> x.cogroup(y).collect()
[('a', ([1], [2])), ('b', ([4], []))]
"""
- assert self.keySerializer.name == other.keySerializer.name
- resultValSerializer = PairSerializer(
- ArraySerializer(self.valSerializer),
- ArraySerializer(other.valSerializer))
- if self.keySerializer.is_comparable:
- return PairRDD(self._jrdd.cogroup(other._jrdd),
- self.ctx, self.keySerializer, resultValSerializer)
- else:
- return python_cogroup(self, other, numSplits)
+ return python_cogroup(self, other, numSplits)
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
@@ -413,44 +300,84 @@ class PairRDD(RDD):
# TODO: file saving
-class MappedRDDBase(object):
- def __init__(self, prev, func, serializer, preservesPartitioning=False):
- if isinstance(prev, MappedRDDBase) and not prev.is_cached:
+class MappedRDD(RDD):
+ """
+ Pipelined maps:
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+
+ Pipelined reduces:
+ >>> from operator import add
+ >>> rdd.map(lambda x: 2 * x).reduce(add)
+ 20
+ >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
+ 20
+ """
+ def __init__(self, prev, func, preservesPartitioning=False, command='map'):
+ if isinstance(prev, MappedRDD) and not prev.is_cached:
prev_func = prev.func
- self.func = lambda x: func(prev_func(x))
+ if command == 'reduce':
+ if prev.command == 'flatmap':
+ def flatmap_reduce_func(x, acc):
+ values = prev_func(x)
+ if values is None:
+ return acc
+ if not acc:
+ if len(values) == 1:
+ return values[0]
+ else:
+ return reduce(func, values[1:], values[0])
+ else:
+ return reduce(func, values, acc)
+ self.func = flatmap_reduce_func
+ else:
+ def reduce_func(x, acc):
+ val = prev_func(x)
+ if not val:
+ return acc
+ if acc is None:
+ return val
+ else:
+ return func(val, acc)
+ self.func = reduce_func
+ else:
+ if prev.command == 'flatmap':
+ command = 'flatmap'
+ self.func = lambda x: (func(y) for y in prev_func(x))
+ else:
+ self.func = lambda x: func(prev_func(x))
+
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
self._prev_jrdd = prev._prev_jrdd
- self._prev_serializer = prev._prev_serializer
+ self.is_pipelined = True
else:
- self.func = func
+ if command == 'reduce':
+ def reduce_func(val, acc):
+ if acc is None:
+ return val
+ else:
+ return func(val, acc)
+ self.func = reduce_func
+ else:
+ self.func = func
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
- self._prev_serializer = prev.serializer
- self.serializer = serializer or prev.ctx.defaultSerializer
+ self.is_pipelined = False
self.is_cached = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
-
-
-class MappedRDD(MappedRDDBase, RDD):
- """
- >>> rdd = sc.parallelize([1, 2, 3, 4])
- >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
- [4, 8, 12, 16]
- >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
- [4, 8, 12, 16]
- """
+ self.command = command
@property
def _jrdd(self):
if not self._jrdd_val:
- udf = self.func
- loads = self._prev_serializer.loads
- dumps = self.serializer.dumps
- func = lambda x: dumps(udf(loads(x)))
- pipe_command = RDD._get_pipe_command("map", [func])
+ funcs = [self.func]
+ pipe_command = RDD._get_pipe_command(self.command, funcs)
class_manifest = self._prev_jrdd.classManifest()
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
@@ -459,56 +386,11 @@ class MappedRDD(MappedRDDBase, RDD):
return self._jrdd_val
-class PairMappedRDD(MappedRDDBase, PairRDD):
- """
- >>> rdd = sc.parallelize([1, 2, 3, 4])
- >>> rdd.mapPairs(lambda x: (x, x)) \\
- ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\
- ... .collect()
- [(2, 2), (4, 4), (6, 6), (8, 8)]
- >>> rdd.mapPairs(lambda x: (x, x)) \\
- ... .mapPairs(lambda (x, y): (2*x, 2*y)) \\
- ... .map(lambda (x, _): x).collect()
- [2, 4, 6, 8]
- """
-
- def __init__(self, prev, func, keySerializer=None, valSerializer=None,
- preservesPartitioning=False):
- self.keySerializer = keySerializer or prev.ctx.defaultSerializer
- self.valSerializer = valSerializer or prev.ctx.defaultSerializer
- serializer = PairSerializer(self.keySerializer, self.valSerializer)
- MappedRDDBase.__init__(self, prev, func, serializer,
- preservesPartitioning)
-
- @property
- def _jrdd(self):
- if not self._jrdd_val:
- udf = self.func
- loads = self._prev_serializer.loads
- dumpk = self.keySerializer.dumps
- dumpv = self.valSerializer.dumps
- def func(x):
- (k, v) = udf(loads(x))
- return (dumpk(k), dumpv(v))
- pipe_command = RDD._get_pipe_command("mapPairs", [func])
- class_manifest = self._prev_jrdd.classManifest()
- self._jrdd_val = self.ctx.jvm.PythonPairRDD(self._prev_jrdd.rdd(),
- pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
- class_manifest).asJavaPairRDD()
- return self._jrdd_val
-
-
def _test():
import doctest
from pyspark.context import SparkContext
- from pyspark.serializers import PickleSerializer, JSONSerializer
globs = globals().copy()
- globs['sc'] = SparkContext('local', 'PythonTest',
- defaultSerializer=JSONSerializer)
- doctest.testmod(globs=globs)
- globs['sc'].stop()
- globs['sc'] = SparkContext('local', 'PythonTest',
- defaultSerializer=PickleSerializer)
+ globs['sc'] = SparkContext('local', 'PythonTest')
doctest.testmod(globs=globs)
globs['sc'].stop()
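The pipelining that MappedRDD now performs amounts to composing the Python functions so that a chain of maps costs only one pass through the worker; in isolation, the map-after-map case is just function composition:

>>> prev_func = lambda x: 2 * x              # first map
>>> func = lambda x: 2 * x                   # second map
>>> composed = lambda x: func(prev_func(x))  # what MappedRDD.__init__ builds when pipelining
>>> [composed(x) for x in [1, 2, 3, 4]]
[4, 8, 12, 16]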
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index b113f5656b..7b3e6966e1 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -2,228 +2,35 @@
Data serialization methods.
The Spark Python API is built on top of the Spark Java API. RDDs created in
-Python are stored in Java as RDDs of Strings. Python objects are automatically
-serialized/deserialized, so this representation is transparent to the end-user.
-
-------------------
-Serializer objects
-------------------
-
-`Serializer` objects are used to customize how an RDD's values are serialized.
-
-Each `Serializer` is a named tuple with four fields:
-
- - A `dumps` function, for serializing a Python object to a string.
-
- - A `loads` function, for deserializing a Python object from a string.
-
- - An `is_comparable` field, True if equal Python objects are serialized to
- equal strings, and False otherwise.
-
- - A `name` field, used to identify the Serializer. Serializers are
- compared for equality by comparing their names.
-
-The serializer's output should be base64-encoded.
-
-------------------------------------------------------------------
-`is_comparable`: comparing serialized representations for equality
-------------------------------------------------------------------
-
-If `is_comparable` is False, the serializer's representations of equal objects
-are not required to be equal:
-
->>> import pickle
->>> a = {1: 0, 9: 0}
->>> b = {9: 0, 1: 0}
->>> a == b
-True
->>> pickle.dumps(a) == pickle.dumps(b)
-False
-
-RDDs with comparable serializers can use native Java implementations of
-operations like join() and distinct(), which may lead to better performance by
-eliminating deserialization and Python comparisons.
-
-The default JSONSerializer produces comparable representations of common Python
-data structures.
-
---------------------------------------
-Examples of serialized representations
---------------------------------------
-
-The RDD transformations that use Python UDFs are implemented in terms of
-a modified `PipedRDD.pipe()` function. For each record `x` in the RDD, the
-`pipe()` function pipes `x.toString()` to a Python worker process, which
-deserializes the string into a Python object, executes user-defined functions,
-and outputs serialized Python objects.
-
-The regular `toString()` method returns an ambiguous representation, due to the
-way that Scala `Option` instances are printed:
-
->>> from context import SparkContext
->>> sc = SparkContext("local", "SerializerDocs")
->>> x = sc.parallelizePairs([("a", 1), ("b", 4)])
->>> y = sc.parallelizePairs([("a", 2)])
-
->>> print y.rightOuterJoin(x)._jrdd.first().toString()
-(ImEi,(Some(Mg==),MQ==))
-
-In Java, preprocessing is performed to handle Option instances, so the Python
-process receives unambiguous input:
-
->>> print sc.python_dump(y.rightOuterJoin(x)._jrdd.first())
-(ImEi,(Mg==,MQ==))
-
-The base64-encoding eliminates the need to escape newlines, parentheses and
-other special characters.
-
-----------------------
-Serializer composition
-----------------------
-
-In order to handle nested structures, which could contain object serialized
-with different serializers, the RDD module composes serializers. For example,
-the serializers in the previous example are:
-
->>> print x.serializer.name
-PairSerializer<JSONSerializer, JSONSerializer>
-
->>> print y.serializer.name
-PairSerializer<JSONSerializer, JSONSerializer>
-
->>> print y.rightOuterJoin(x).serializer.name
-PairSerializer<JSONSerializer, PairSerializer<OptionSerializer<JSONSerializer>, JSONSerializer>>
+Python are stored in Java as RDD[Array[Byte]]. Python objects are
+automatically serialized/deserialized, so this representation is transparent to
+the end-user.
"""
-from base64 import standard_b64encode, standard_b64decode
from collections import namedtuple
import cPickle
-import simplejson
-
-
-Serializer = namedtuple("Serializer",
- ["dumps","loads", "is_comparable", "name"])
-
-
-NopSerializer = Serializer(str, str, True, "NopSerializer")
+import struct
-JSONSerializer = Serializer(
- lambda obj: standard_b64encode(simplejson.dumps(obj, sort_keys=True,
- separators=(',', ':'))),
- lambda s: simplejson.loads(standard_b64decode(s)),
- True,
- "JSONSerializer"
-)
+Serializer = namedtuple("Serializer", ["dumps","loads"])
PickleSerializer = Serializer(
- lambda obj: standard_b64encode(cPickle.dumps(obj)),
- lambda s: cPickle.loads(standard_b64decode(s)),
- False,
- "PickleSerializer"
-)
-
-
-def OptionSerializer(serializer):
- """
- >>> ser = OptionSerializer(NopSerializer)
- >>> ser.loads(ser.dumps("Hello, World!"))
- 'Hello, World!'
- >>> ser.loads(ser.dumps(None)) is None
- True
- """
- none_placeholder = '*'
-
- def dumps(x):
- if x is None:
- return none_placeholder
- else:
- return serializer.dumps(x)
-
- def loads(x):
- if x == none_placeholder:
- return None
- else:
- return serializer.loads(x)
-
- name = "OptionSerializer<%s>" % serializer.name
- return Serializer(dumps, loads, serializer.is_comparable, name)
-
-
-def PairSerializer(keySerializer, valSerializer):
- """
- Returns a Serializer for a (key, value) pair.
-
- >>> ser = PairSerializer(JSONSerializer, JSONSerializer)
- >>> ser.loads(ser.dumps((1, 2)))
- (1, 2)
-
- >>> ser = PairSerializer(JSONSerializer, ser)
- >>> ser.loads(ser.dumps((1, (2, 3))))
- (1, (2, 3))
- """
- def loads(kv):
- try:
- (key, val) = kv[1:-1].split(',', 1)
- key = keySerializer.loads(key)
- val = valSerializer.loads(val)
- return (key, val)
- except:
- print "Error in deserializing pair from '%s'" % str(kv)
- raise
-
- def dumps(kv):
- (key, val) = kv
- return"(%s,%s)" % (keySerializer.dumps(key), valSerializer.dumps(val))
- is_comparable = \
- keySerializer.is_comparable and valSerializer.is_comparable
- name = "PairSerializer<%s, %s>" % (keySerializer.name, valSerializer.name)
- return Serializer(dumps, loads, is_comparable, name)
-
-
-def ArraySerializer(serializer):
- """
- >>> ser = ArraySerializer(JSONSerializer)
- >>> ser.loads(ser.dumps([1, 2, 3, 4]))
- [1, 2, 3, 4]
- >>> ser = ArraySerializer(PairSerializer(JSONSerializer, PickleSerializer))
- >>> ser.loads(ser.dumps([('a', 1), ('b', 2)]))
- [('a', 1), ('b', 2)]
- >>> ser.loads(ser.dumps([('a', 1)]))
- [('a', 1)]
- >>> ser.loads(ser.dumps([]))
- []
- """
- def dumps(arr):
- if arr == []:
- return '[]'
- else:
- return '[' + '|'.join(serializer.dumps(x) for x in arr) + ']'
-
- def loads(s):
- if s == '[]':
- return []
- items = s[1:-1]
- if '|' in items:
- items = items.split('|')
- else:
- items = [items]
- return [serializer.loads(x) for x in items]
-
- name = "ArraySerializer<%s>" % serializer.name
- return Serializer(dumps, loads, serializer.is_comparable, name)
-
-
-# TODO: IntegerSerializer
-
-
-# TODO: DoubleSerializer
+ lambda obj: cPickle.dumps(obj, -1),
+ cPickle.loads)
-def _test():
- import doctest
- doctest.testmod()
+def dumps(obj, stream):
+ # TODO: determining the length of non-byte objects.
+ stream.write(struct.pack("!i", len(obj)))
+ stream.write(obj)
-if __name__ == "__main__":
- _test()
+def loads(stream):
+ length = stream.read(4)
+ if length == "":
+ raise EOFError
+ length = struct.unpack("!i", length)[0]
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return obj
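A quick round trip through the new framing helpers (a sketch; StringIO stands in for the pipe between the JVM and the Python worker):

>>> from StringIO import StringIO
>>> from pyspark.serializers import dumps, loads, PickleSerializer
>>> stream = StringIO()
>>> dumps(PickleSerializer.dumps(("a", 1)), stream)   # write a length-prefixed pickle
>>> stream.seek(0)
>>> PickleSerializer.loads(loads(stream))             # read the frame back and unpickle it
('a', 1)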
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 4c4b02fce4..21ff84fb17 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -6,9 +6,9 @@ from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import dumps, loads, PickleSerializer
import cPickle
-
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = sys.stdout
sys.stdout = sys.stderr
@@ -19,58 +19,64 @@ def load_function():
def output(x):
- for line in x.split("\n"):
- old_stdout.write(line.rstrip("\r\n") + "\n")
+ dumps(x, old_stdout)
def read_input():
- for line in sys.stdin:
- yield line.rstrip("\r\n")
-
+ try:
+ while True:
+ yield loads(sys.stdin)
+ except EOFError:
+ return
def do_combine_by_key():
create_combiner = load_function()
merge_value = load_function()
merge_combiners = load_function() # TODO: not used.
- depickler = load_function()
- key_pickler = load_function()
- combiner_pickler = load_function()
combiners = {}
- for line in read_input():
- # Discard the hashcode added in the Python combineByKey() method.
- (key, value) = depickler(line)[1]
+ for obj in read_input():
+ (key, value) = PickleSerializer.loads(obj)
if key not in combiners:
combiners[key] = create_combiner(value)
else:
combiners[key] = merge_value(combiners[key], value)
for (key, combiner) in combiners.iteritems():
- output(key_pickler(key))
- output(combiner_pickler(combiner))
+ output(PickleSerializer.dumps((key, combiner)))
-def do_map(map_pairs=False):
+def do_map(flat=False):
f = load_function()
- for line in read_input():
+ for obj in read_input():
try:
- out = f(line)
+ #from pickletools import dis
+ #print repr(obj)
+ #print dis(obj)
+ out = f(PickleSerializer.loads(obj))
if out is not None:
- if map_pairs:
+ if flat:
for x in out:
- output(x)
+ output(PickleSerializer.dumps(x))
else:
- output(out)
+ output(PickleSerializer.dumps(out))
except:
- sys.stderr.write("Error processing line '%s'\n" % line)
+ sys.stderr.write("Error processing obj %s\n" % repr(obj))
raise
+def do_shuffle_map_step():
+ for obj in read_input():
+ key = PickleSerializer.loads(obj)[1]
+ output(str(hash(key)))
+ output(obj)
+
+
def do_reduce():
f = load_function()
- dumps = load_function()
acc = None
- for line in read_input():
- acc = f(line, acc)
- output(dumps(acc))
+ for obj in read_input():
+ acc = f(PickleSerializer.loads(obj), acc)
+ if acc is not None:
+ output(PickleSerializer.dumps(acc))
def do_echo():
@@ -80,13 +86,15 @@ def do_echo():
def main():
command = sys.stdin.readline().strip()
if command == "map":
- do_map(map_pairs=False)
- elif command == "mapPairs":
- do_map(map_pairs=True)
+ do_map(flat=False)
+ elif command == "flatmap":
+ do_map(flat=True)
elif command == "combine_by_key":
do_combine_by_key()
elif command == "reduce":
do_reduce()
+ elif command == "shuffle_map_step":
+ do_shuffle_map_step()
elif command == "echo":
do_echo()
else: