author     Josh Rosen <joshrosen@eecs.berkeley.edu>   2012-08-25 16:46:07 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>   2012-08-27 00:24:39 -0700
commit     200d248dcc5903295296bf897211cf543b37f8c1 (patch)
tree       46df15fbccf99489a1f7f240c71cc56ef083d6d8 /pyspark
parent     6904cb77d4306a14891cc71338c8f9f966d009f1 (diff)
Simplify Python worker; pipeline the map step of partitionBy().
Diffstat (limited to 'pyspark')
-rw-r--r--  pyspark/pyspark/context.py      |  9
-rw-r--r--  pyspark/pyspark/rdd.py          | 70
-rw-r--r--  pyspark/pyspark/serializers.py  | 23
-rw-r--r--  pyspark/pyspark/worker.py       | 50
4 files changed, 52 insertions, 100 deletions
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 6f87206665..b8490019e3 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, dumps
+from pyspark.serializers import dump_pickle, write_with_length
from pyspark.rdd import RDD
@@ -16,9 +16,8 @@ class SparkContext(object):
asPickle = jvm.spark.api.python.PythonRDD.asPickle
arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
-
def __init__(self, master, name, defaultParallelism=None,
- pythonExec='python'):
+ pythonExec='python'):
self.master = master
self.name = name
self._jsc = self.jvm.JavaSparkContext(master, name)
@@ -52,7 +51,7 @@ class SparkContext(object):
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
for x in c:
- dumps(PickleSerializer.dumps(x), tempFile)
+ write_with_length(dump_pickle(x), tempFile)
tempFile.close()
atexit.register(lambda: os.unlink(tempFile.name))
jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
@@ -64,6 +63,6 @@ class SparkContext(object):
return RDD(jrdd, self)
def broadcast(self, value):
- jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+ jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
self._pickled_broadcast_vars)
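
As a quick illustration of the parallelize() change above: each element is now written to the temporary file as a length-prefixed pickle via the dump_pickle and write_with_length helpers introduced in serializers.py further down. Below is a minimal sketch of that write path in the Python 2 style of the original code; read_back() is a hypothetical helper added only to show how such a file can be consumed, and is not part of the commit.

    import cPickle
    import struct
    from tempfile import NamedTemporaryFile

    def dump_pickle(obj):
        # Same as the new serializers.dump_pickle: pickle protocol 2.
        return cPickle.dumps(obj, 2)

    def write_with_length(obj, stream):
        # Same framing as serializers.write_with_length:
        # 4-byte big-endian length followed by the payload.
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    def read_back(path):
        # Hypothetical reader, for illustration only.
        with open(path, "rb") as stream:
            while True:
                header = stream.read(4)
                if not header:
                    return
                (length,) = struct.unpack("!i", header)
                yield cPickle.loads(stream.read(length))

    tempFile = NamedTemporaryFile(delete=False)
    for x in [1, 2, 3]:
        write_with_length(dump_pickle(x), tempFile)
    tempFile.close()
    assert list(read_back(tempFile.name)) == [1, 2, 3]
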
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 3528b8f308..21e822ba9f 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -3,7 +3,7 @@ from collections import Counter
from itertools import chain, ifilter, imap
from pyspark import cloudpickle
-from pyspark.serializers import PickleSerializer
+from pyspark.serializers import dump_pickle, load_pickle
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
@@ -17,17 +17,6 @@ class RDD(object):
self.is_cached = False
self.ctx = ctx
- @classmethod
- def _get_pipe_command(cls, ctx, command, functions):
- worker_args = [command]
- for f in functions:
- worker_args.append(b64enc(cloudpickle.dumps(f)))
- broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
- broadcast_vars = ListConverter().convert(broadcast_vars,
- ctx.gateway._gateway_client)
- ctx._pickled_broadcast_vars.clear()
- return (" ".join(worker_args), broadcast_vars)
-
def cache(self):
self.is_cached = True
self._jrdd.cache()
@@ -66,14 +55,6 @@ class RDD(object):
def func(iterator): return ifilter(f, iterator)
return self.mapPartitions(func)
- def _pipe(self, functions, command):
- class_manifest = self._jrdd.classManifest()
- (pipe_command, broadcast_vars) = \
- RDD._get_pipe_command(self.ctx, command, functions)
- python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
- False, self.ctx.pythonExec, broadcast_vars, class_manifest)
- return python_rdd.asJavaRDD()
-
def distinct(self):
"""
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
@@ -89,7 +70,7 @@ class RDD(object):
def takeSample(self, withReplacement, num, seed):
vals = self._jrdd.takeSample(withReplacement, num, seed)
- return [PickleSerializer.loads(bytes(x)) for x in vals]
+ return [load_pickle(bytes(x)) for x in vals]
def union(self, other):
"""
@@ -148,7 +129,7 @@ class RDD(object):
def collect(self):
pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
- return PickleSerializer.loads(bytes(pickle))
+ return load_pickle(bytes(pickle))
def reduce(self, f):
"""
@@ -216,19 +197,17 @@ class RDD(object):
[2, 3]
"""
pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
- return PickleSerializer.loads(bytes(pickle))
+ return load_pickle(bytes(pickle))
def first(self):
"""
>>> sc.parallelize([2, 3, 4]).first()
2
"""
- return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first())))
+ return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
# TODO: saveAsTextFile
- # TODO: saveAsObjectFile
-
# Pair functions
def collectAsMap(self):
@@ -303,19 +282,18 @@ class RDD(object):
"""
return python_right_outer_join(self, other, numSplits)
- # TODO: pipelining
- # TODO: optimizations
def partitionBy(self, numSplits, hashFunc=hash):
if numSplits is None:
numSplits = self.ctx.defaultParallelism
- (pipe_command, broadcast_vars) = \
- RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
- class_manifest = self._jrdd.classManifest()
- python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
- pipe_command, False, self.ctx.pythonExec, broadcast_vars,
- class_manifest)
+ def add_shuffle_key(iterator):
+ for (k, v) in iterator:
+ yield str(hashFunc(k))
+ yield dump_pickle((k, v))
+ keyed = PipelinedRDD(self, add_shuffle_key)
+ keyed._bypass_serializer = True
+ pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
- jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
+ jrdd = pairRDD.partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
return RDD(jrdd, self.ctx)
@@ -430,17 +408,23 @@ class PipelinedRDD(RDD):
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._bypass_serializer = False
@property
def _jrdd(self):
- if not self._jrdd_val:
- (pipe_command, broadcast_vars) = \
- RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
- class_manifest = self._prev_jrdd.classManifest()
- python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, class_manifest)
- self._jrdd_val = python_rdd.asJavaRDD()
+ if self._jrdd_val:
+ return self._jrdd_val
+ funcs = [self.func, self._bypass_serializer]
+ pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+ self.ctx.gateway._gateway_client)
+ self.ctx._pickled_broadcast_vars.clear()
+ class_manifest = self._prev_jrdd.classManifest()
+ python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
+ broadcast_vars, class_manifest)
+ self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
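
The interesting part of the rdd.py change is the pipelined partitionBy(): add_shuffle_key() emits, for every (key, value) pair, the stringified hash of the key followed by the pickled pair, and _bypass_serializer tells the worker to stream those records out verbatim; the JVM-side PairwiseRDD (not shown in this diff) presumably pairs consecutive records back up before HashPartitioner distributes them. A rough, self-contained sketch of that interleaving over a plain Python iterator follows; group_pairs() is a hypothetical stand-in for what PairwiseRDD does on the Scala side.

    import cPickle

    def dump_pickle(obj):
        return cPickle.dumps(obj, 2)

    def add_shuffle_key(iterator, hashFunc=hash):
        # Mirrors the new partitionBy() map step: alternate hash strings
        # and pickled pairs in a single flat stream of records.
        for (k, v) in iterator:
            yield str(hashFunc(k))       # shuffle key, as a string
            yield dump_pickle((k, v))    # the original pair, pickled

    def group_pairs(stream):
        # Hypothetical stand-in for the JVM PairwiseRDD: consume the flat
        # stream two records at a time.
        stream = iter(stream)
        for hashed_key in stream:
            yield (hashed_key, cPickle.loads(next(stream)))

    records = [("a", 1), ("b", 2), ("a", 3)]
    grouped = list(group_pairs(add_shuffle_key(records)))
    assert grouped[0] == (str(hash("a")), ("a", 1))
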
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 7b3e6966e1..faa1e683c7 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -1,31 +1,20 @@
-"""
-Data serialization methods.
-
-The Spark Python API is built on top of the Spark Java API. RDDs created in
-Python are stored in Java as RDD[Array[Byte]]. Python objects are
-automatically serialized/deserialized, so this representation is transparent to
-the end-user.
-"""
-from collections import namedtuple
-import cPickle
import struct
+import cPickle
-Serializer = namedtuple("Serializer", ["dumps","loads"])
+def dump_pickle(obj):
+ return cPickle.dumps(obj, 2)
-PickleSerializer = Serializer(
- lambda obj: cPickle.dumps(obj, -1),
- cPickle.loads)
+load_pickle = cPickle.loads
-def dumps(obj, stream):
- # TODO: determining the length of non-byte objects.
+def write_with_length(obj, stream):
stream.write(struct.pack("!i", len(obj)))
stream.write(obj)
-def loads(stream):
+def read_with_length(stream):
length = stream.read(4)
if length == "":
raise EOFError
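
The read_with_length() hunk above is cut off after the EOFError check. For reference, a hedged reconstruction of the full framing pair follows; the tail of read_with_length() (unpacking the length and reading the payload) is an assumption inferred from write_with_length(), not text shown in the truncated hunk.

    import struct
    from StringIO import StringIO

    def write_with_length(obj, stream):
        # As in the diff: 4-byte big-endian length, then the payload.
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    def read_with_length(stream):
        length = stream.read(4)
        if length == "":
            raise EOFError
        # From here on is an assumption; the hunk above is truncated.
        (length,) = struct.unpack("!i", length)
        obj = stream.read(length)
        if obj == "":
            raise EOFError
        return obj

    buf = StringIO()
    write_with_length("hello", buf)
    buf.seek(0)
    assert read_with_length(buf) == "hello"
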
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 0f90c6ff46..a9ed71892f 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -7,61 +7,41 @@ from base64 import standard_b64decode
# copy_reg module.
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import dumps, loads, PickleSerializer
-import cPickle
+from pyspark.serializers import write_with_length, read_with_length, \
+ dump_pickle, load_pickle
+
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = sys.stdout
sys.stdout = sys.stderr
-def load_function():
- return cPickle.loads(standard_b64decode(sys.stdin.readline().strip()))
-
-
-def output(x):
- dumps(x, old_stdout)
+def load_obj():
+ return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
def read_input():
try:
while True:
- yield cPickle.loads(loads(sys.stdin))
+ yield load_pickle(read_with_length(sys.stdin))
except EOFError:
return
-def do_pipeline():
- f = load_function()
- for obj in f(read_input()):
- output(PickleSerializer.dumps(obj))
-
-
-def do_shuffle_map_step():
- hashFunc = load_function()
- while True:
- try:
- pickled = loads(sys.stdin)
- except EOFError:
- return
- key = cPickle.loads(pickled)[0]
- output(str(hashFunc(key)))
- output(pickled)
-
-
def main():
num_broadcast_variables = int(sys.stdin.readline().strip())
for _ in range(num_broadcast_variables):
uuid = sys.stdin.read(36)
- value = loads(sys.stdin)
- _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
- command = sys.stdin.readline().strip()
- if command == "pipeline":
- do_pipeline()
- elif command == "shuffle_map_step":
- do_shuffle_map_step()
+ value = read_with_length(sys.stdin)
+ _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+ func = load_obj()
+ bypassSerializer = load_obj()
+ if bypassSerializer:
+ dumps = lambda x: x
else:
- raise Exception("Unsupported command %s" % command)
+ dumps = dump_pickle
+ for obj in func(read_input()):
+ write_with_length(dumps(obj), old_stdout)
if __name__ == '__main__':
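
Taken together, the worker.py rewrite replaces the old command dispatch ("pipeline" vs. "shuffle_map_step") with a single fixed protocol: the worker reads the broadcast-variable count and values, then a pickled function, then a bypass-serializer flag, and finally streams length-prefixed records through the function to the original stdout; when the flag is set (as partitionBy() does) the outputs are written verbatim instead of being re-pickled. Below is a condensed sketch that simulates this handshake in-process over StringIO rather than real stdin/stdout pipes. It assumes the JVM side writes one base64 chunk per line (PythonRDD.scala is not part of this diff), and it ships a plain module-level function with cPickle where the real code uses cloudpickle so that arbitrary closures work.

    import cPickle
    import struct
    from base64 import standard_b64encode, standard_b64decode
    from StringIO import StringIO

    def write_with_length(obj, stream):
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    def read_with_length(stream):
        header = stream.read(4)
        if header == "":
            raise EOFError
        (length,) = struct.unpack("!i", header)
        return stream.read(length)

    def double_all(iterator):
        return (x * 2 for x in iterator)

    # Driver side: what the JVM writes to the worker's stdin (simulated;
    # one base64 chunk per line is an assumption about PythonRDD.scala).
    stdin = StringIO()
    stdin.write("0\n")                                                    # broadcast count
    stdin.write(standard_b64encode(cPickle.dumps(double_all, 2)) + "\n")  # the function
    stdin.write(standard_b64encode(cPickle.dumps(False, 2)) + "\n")       # bypass flag
    for x in [1, 2, 3]:
        write_with_length(cPickle.dumps(x, 2), stdin)                     # input records
    stdin.seek(0)

    # Worker side: mirrors main() above.
    stdout = StringIO()
    assert int(stdin.readline()) == 0            # no broadcast vars in this demo
    func = cPickle.loads(standard_b64decode(stdin.readline().strip()))
    bypass = cPickle.loads(standard_b64decode(stdin.readline().strip()))
    dumps = (lambda x: x) if bypass else (lambda x: cPickle.dumps(x, 2))

    def read_input(stream):
        try:
            while True:
                yield cPickle.loads(read_with_length(stream))
        except EOFError:
            return

    for obj in func(read_input(stdin)):
        write_with_length(dumps(obj), stdout)

    # Read the worker's output back to check the round trip.
    stdout.seek(0)
    results = []
    try:
        while True:
            results.append(cPickle.loads(read_with_length(stdout)))
    except EOFError:
        pass
    assert results == [2, 4, 6]
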