aboutsummaryrefslogtreecommitdiff
path: root/pyspark
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2012-08-25 13:59:01 -0700
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2012-08-27 00:16:47 -0700
commitf79a1e4d2a8643157136de69b8d7de84f0034712 (patch)
tree679f4453a382e121ffd4fc0c9f3a77a4b292d14c /pyspark
parent65e8406029a0fe1e1c5c5d033d335b43f6743a04 (diff)
downloadspark-f79a1e4d2a8643157136de69b8d7de84f0034712.tar.gz
spark-f79a1e4d2a8643157136de69b8d7de84f0034712.tar.bz2
spark-f79a1e4d2a8643157136de69b8d7de84f0034712.zip
Add broadcast variables to Python API.
Diffstat (limited to 'pyspark')
-rw-r--r--pyspark/pyspark/broadcast.py46
-rw-r--r--pyspark/pyspark/context.py17
-rw-r--r--pyspark/pyspark/rdd.py27
-rw-r--r--pyspark/pyspark/worker.py6
4 files changed, 84 insertions, 12 deletions
diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py
new file mode 100644
index 0000000000..1ea17d59af
--- /dev/null
+++ b/pyspark/pyspark/broadcast.py
@@ -0,0 +1,46 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.uuid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+"""
+# Holds broadcasted data received from Java, keyed by UUID.
+_broadcastRegistry = {}
+
+
+def _from_uuid(uuid):
+ from pyspark.broadcast import _broadcastRegistry
+ if uuid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % uuid)
+ return _broadcastRegistry[uuid]
+
+
+class Broadcast(object):
+ def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+ self.value = value
+ self.uuid = uuid
+ self._jbroadcast = java_broadcast
+ self._pickle_registry = pickle_registry
+
+ def __reduce__(self):
+ self._pickle_registry.add(self)
+ return (_from_uuid, (self.uuid, ))
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index ac7e4057e9..6f87206665 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -2,6 +2,7 @@ import os
import atexit
from tempfile import NamedTemporaryFile
+from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, dumps
from pyspark.rdd import RDD
@@ -24,6 +25,11 @@ class SparkContext(object):
self.defaultParallelism = \
defaultParallelism or self._jsc.sc().defaultParallelism()
self.pythonExec = pythonExec
+ # Broadcast's __reduce__ method stores Broadcast instances here.
+ # This allows other code to determine which Broadcast instances have
+ # been pickled, so it can determine which Java broadcast objects to
+ # send.
+ self._pickled_broadcast_vars = set()
def __del__(self):
if self._jsc:
@@ -52,7 +58,12 @@ class SparkContext(object):
jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
- def textFile(self, name, numSlices=None):
- numSlices = numSlices or self.defaultParallelism
- jrdd = self._jsc.textFile(name, numSlices)
+ def textFile(self, name, minSplits=None):
+ minSplits = minSplits or min(self.defaultParallelism, 2)
+ jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
+
+ def broadcast(self, value):
+ jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+ return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+ self._pickled_broadcast_vars)
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index af7703fdfc..4459095391 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
+from py4j.java_collections import ListConverter
+
class RDD(object):
@@ -15,11 +17,15 @@ class RDD(object):
self.ctx = ctx
@classmethod
- def _get_pipe_command(cls, command, functions):
+ def _get_pipe_command(cls, ctx, command, functions):
worker_args = [command]
for f in functions:
worker_args.append(b64enc(cloudpickle.dumps(f)))
- return " ".join(worker_args)
+ broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
+ broadcast_vars = ListConverter().convert(broadcast_vars,
+ ctx.gateway._gateway_client)
+ ctx._pickled_broadcast_vars.clear()
+ return (" ".join(worker_args), broadcast_vars)
def cache(self):
self.is_cached = True
@@ -52,9 +58,10 @@ class RDD(object):
def _pipe(self, functions, command):
class_manifest = self._jrdd.classManifest()
- pipe_command = RDD._get_pipe_command(command, functions)
+ (pipe_command, broadcast_vars) = \
+ RDD._get_pipe_command(self.ctx, command, functions)
python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
- False, self.ctx.pythonExec, class_manifest)
+ False, self.ctx.pythonExec, broadcast_vars, class_manifest)
return python_rdd.asJavaRDD()
def distinct(self):
@@ -249,10 +256,12 @@ class RDD(object):
def shuffle(self, numSplits, hashFunc=hash):
if numSplits is None:
numSplits = self.ctx.defaultParallelism
- pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
+ (pipe_command, broadcast_vars) = \
+ RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
class_manifest = self._jrdd.classManifest()
python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
- pipe_command, False, self.ctx.pythonExec, class_manifest)
+ pipe_command, False, self.ctx.pythonExec, broadcast_vars,
+ class_manifest)
partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
@property
def _jrdd(self):
if not self._jrdd_val:
- funcs = [self.func]
- pipe_command = RDD._get_pipe_command("pipeline", funcs)
+ (pipe_command, broadcast_vars) = \
+ RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
class_manifest = self._prev_jrdd.classManifest()
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
- class_manifest)
+ broadcast_vars, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 76b09918e7..7402897ac8 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -5,6 +5,7 @@ import sys
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import dumps, loads, PickleSerializer
import cPickle
@@ -63,6 +64,11 @@ def do_shuffle_map_step():
def main():
+ num_broadcast_variables = int(sys.stdin.readline().strip())
+ for _ in range(num_broadcast_variables):
+ uuid = sys.stdin.read(36)
+ value = loads(sys.stdin)
+ _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
command = sys.stdin.readline().strip()
if command == "pipeline":
do_pipeline()