author     Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-25 13:59:01 -0700
committer  Josh Rosen <joshrosen@eecs.berkeley.edu>  2012-08-27 00:16:47 -0700
commit     f79a1e4d2a8643157136de69b8d7de84f0034712 (patch)
tree       679f4453a382e121ffd4fc0c9f3a77a4b292d14c
parent     65e8406029a0fe1e1c5c5d033d335b43f6743a04 (diff)
Add broadcast variables to Python API.
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala  43
-rw-r--r--  pyspark/pyspark/broadcast.py                           46
-rw-r--r--  pyspark/pyspark/context.py                             17
-rw-r--r--  pyspark/pyspark/rdd.py                                 27
-rw-r--r--  pyspark/pyspark/worker.py                               6
5 files changed, 110 insertions(+), 29 deletions(-)
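
For quick reference, the new driver-side API is exercised by the doctest added in pyspark/pyspark/broadcast.py; the sketch below mirrors that doctest (the 'local' master URL and 'test' job name are just example arguments):

# Driver-side usage sketch, mirroring the doctest in pyspark/pyspark/broadcast.py.
from pyspark.context import SparkContext

sc = SparkContext('local', 'test')
b = sc.broadcast([1, 2, 3, 4, 5])      # ship the value to the cluster once
print b.value                          # [1, 2, 3, 4, 5] on the driver
print sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
# [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]       <- workers resolve b through the registry
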
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 93847e2f14..5163812df4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -7,14 +7,13 @@ import scala.collection.JavaConversions._
import scala.io.Source
import spark._
import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import scala.{collection, Some}
-import collection.parallel.mutable
+import broadcast.Broadcast
import scala.collection
-import scala.Some
trait PythonRDDBase {
def compute[T](split: Split, envVars: Map[String, String],
- command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = {
+ command: Seq[String], parent: RDD[T], pythonExec: String,
+ broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@@ -42,11 +41,18 @@ trait PythonRDDBase {
      override def run() {
        SparkEnv.set(env)
        val out = new PrintWriter(proc.getOutputStream)
+        val dOut = new DataOutputStream(proc.getOutputStream)
+        out.println(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          out.print(broadcast.uuid.toString)
+          dOut.writeInt(broadcast.value.length)
+          dOut.write(broadcast.value)
+          dOut.flush()
+        }
        for (elem <- command) {
          out.println(elem)
        }
        out.flush()
-        val dOut = new DataOutputStream(proc.getOutputStream)
        for (elem <- parent.iterator(split)) {
          if (elem.isInstanceOf[Array[Byte]]) {
            val arr = elem.asInstanceOf[Array[Byte]]
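
The stdin-writer thread above now prepends a small broadcast handshake before the usual command and data: a text line with the number of broadcast variables, then for each variable its 36-character UUID (no newline), a 4-byte big-endian length written by DataOutputStream.writeInt, and the pickled bytes. A rough Python sketch of that framing, using an in-memory buffer in place of the worker's stdin (write_broadcasts/read_broadcasts are illustrative names, not part of the patch):

# Illustrative sketch of the broadcast handshake framing; the helper names and
# the in-memory buffer are stand-ins, not part of the patch.
import struct
import cPickle
from io import BytesIO

def write_broadcasts(out, broadcasts):
    # broadcasts: list of (uuid_string, value) pairs
    out.write("%d\n" % len(broadcasts))             # count, as a text line
    for uuid, value in broadcasts:
        payload = cPickle.dumps(value, 2)
        out.write(uuid)                             # 36-char UUID, no newline
        out.write(struct.pack(">i", len(payload)))  # 4-byte big-endian length
        out.write(payload)                          # pickled value bytes

def read_broadcasts(stream):
    registry = {}
    for _ in range(int(stream.readline().strip())):
        uuid = stream.read(36)
        (length,) = struct.unpack(">i", stream.read(4))
        registry[uuid] = cPickle.loads(stream.read(length))
    return registry

buf = BytesIO()
write_broadcasts(buf, [("12345678-1234-1234-1234-123456789012", [1, 2, 3])])
buf.seek(0)
print read_broadcasts(buf)   # {'12345678-1234-1234-1234-123456789012': [1, 2, 3]}
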
@@ -121,16 +127,17 @@ trait PythonRDDBase {
class PythonRDD[T: ClassManifest](
parent: RDD[T], command: Seq[String], envVars: Map[String, String],
- preservePartitoning: Boolean, pythonExec: String)
+ preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
extends RDD[Array[Byte]](parent.context) with PythonRDDBase {
- def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
- this(parent, command, Map(), preservePartitoning, pythonExec)
+ def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+ pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
- this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+ def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
override def splits = parent.splits
@@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split): Iterator[Array[Byte]] =
- compute(split, envVars, command, parent, pythonExec)
+ compute(split, envVars, command, parent, pythonExec, broadcastVars)
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
class PythonPairRDD[T: ClassManifest] (
parent: RDD[T], command: Seq[String], envVars: Map[String, String],
- preservePartitoning: Boolean, pythonExec: String)
+ preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {
- def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
- this(parent, command, Map(), preservePartitoning, pythonExec)
+ def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+ pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
- this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+ def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
override def splits = parent.splits
@@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] (
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
- compute(split, envVars, command, parent, pythonExec).grouped(2).map {
+ compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
}
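
PythonPairRDD reuses the same single stream of byte arrays and simply pairs consecutive elements with grouped(2), so keys and values are expected to alternate. The Python equivalent of that pairing, shown only to make the convention explicit (pair_up is an illustrative name):

# Illustrative only: pairing an alternating key/value stream, mirroring the
# grouped(2) call in PythonPairRDD.compute.
def pair_up(elements):
    it = iter(elements)
    for key in it:
        try:
            value = next(it)
        except StopIteration:
            raise Exception("unexpected odd number of elements: no value for %r" % key)
        yield (key, value)

print list(pair_up(["k1", "v1", "k2", "v2"]))   # [('k1', 'v1'), ('k2', 'v2')]
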
diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py
new file mode 100644
index 0000000000..1ea17d59af
--- /dev/null
+++ b/pyspark/pyspark/broadcast.py
@@ -0,0 +1,46 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.uuid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+"""
+# Holds broadcasted data received from Java, keyed by UUID.
+_broadcastRegistry = {}
+
+
+def _from_uuid(uuid):
+    from pyspark.broadcast import _broadcastRegistry
+    if uuid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
+    return _broadcastRegistry[uuid]
+
+
+class Broadcast(object):
+    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+        self.value = value
+        self.uuid = uuid
+        self._jbroadcast = java_broadcast
+        self._pickle_registry = pickle_registry
+
+    def __reduce__(self):
+        self._pickle_registry.add(self)
+        return (_from_uuid, (self.uuid, ))
+
+
+def _test():
+    import doctest
+    doctest.testmod()
+
+
+if __name__ == "__main__":
+    _test()
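
The interesting part of this module is __reduce__: pickling a Broadcast records it in the context's registry of shipped broadcasts and serializes only the UUID, while unpickling on the worker resolves that UUID through _broadcastRegistry, which worker.py fills in before running any commands. A minimal sketch of that round trip with no JVM involved, using a plain set as a stand-in for SparkContext._pickled_broadcast_vars:

# Minimal pickle round-trip sketch; no SparkContext or JVM involved.
from cPickle import dumps, loads
from pyspark.broadcast import Broadcast, _broadcastRegistry

shipped = set()   # stand-in for SparkContext._pickled_broadcast_vars
b = Broadcast("00000000-0000-0000-0000-000000000000", [1, 2, 3],
              pickle_registry=shipped)

data = dumps(b)                   # __reduce__ registers b and keeps only the uuid
assert b in shipped               # the driver now knows this broadcast must be sent

_broadcastRegistry[b.uuid] = b    # worker.py does this before evaluating commands
print loads(data).value           # [1, 2, 3], resolved via _from_uuid
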
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index ac7e4057e9..6f87206665 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -2,6 +2,7 @@ import os
import atexit
from tempfile import NamedTemporaryFile
+from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, dumps
from pyspark.rdd import RDD
@@ -24,6 +25,11 @@ class SparkContext(object):
        self.defaultParallelism = \
            defaultParallelism or self._jsc.sc().defaultParallelism()
        self.pythonExec = pythonExec
+        # Broadcast's __reduce__ method stores Broadcast instances here.
+        # This allows other code to determine which Broadcast instances have
+        # been pickled, so it can determine which Java broadcast objects to
+        # send.
+        self._pickled_broadcast_vars = set()

    def __del__(self):
        if self._jsc:
@@ -52,7 +58,12 @@ class SparkContext(object):
        jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self)

-    def textFile(self, name, numSlices=None):
-        numSlices = numSlices or self.defaultParallelism
-        jrdd = self._jsc.textFile(name, numSlices)
+    def textFile(self, name, minSplits=None):
+        minSplits = minSplits or min(self.defaultParallelism, 2)
+        jrdd = self._jsc.textFile(name, minSplits)
        return RDD(jrdd, self)
+
+    def broadcast(self, value):
+        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+                         self._pickled_broadcast_vars)
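
SparkContext.broadcast pickles the value on the driver, wraps it in a bytearray so Py4J hands it to the JVM as a byte[] (matching Broadcast[Array[Byte]] on the Scala side), and keys the Python wrapper by the Java broadcast's UUID. A step-by-step sketch of those calls, assuming a local PySpark build with a working Java gateway; the example value is arbitrary:

# Sketch of what SparkContext.broadcast(value) does internally, spelled out.
from pyspark.context import SparkContext
from pyspark.serializers import PickleSerializer

sc = SparkContext('local', 'test')
value = {"weights": [0.1, 0.2, 0.3]}
pickled = PickleSerializer.dumps(value)              # pickled bytes on the driver
jbroadcast = sc._jsc.broadcast(bytearray(pickled))   # Py4J passes the bytearray to the
                                                     # JVM as byte[] -> Broadcast[Array[Byte]]
print jbroadcast.uuid().toString()                   # identifier shared with the workers
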
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index af7703fdfc..4459095391 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
from pyspark.join import python_join, python_left_outer_join, \
    python_right_outer_join, python_cogroup
+from py4j.java_collections import ListConverter
+
class RDD(object):
@@ -15,11 +17,15 @@ class RDD(object):
        self.ctx = ctx

    @classmethod
-    def _get_pipe_command(cls, command, functions):
+    def _get_pipe_command(cls, ctx, command, functions):
        worker_args = [command]
        for f in functions:
            worker_args.append(b64enc(cloudpickle.dumps(f)))
-        return " ".join(worker_args)
+        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
+        broadcast_vars = ListConverter().convert(broadcast_vars,
+                                                 ctx.gateway._gateway_client)
+        ctx._pickled_broadcast_vars.clear()
+        return (" ".join(worker_args), broadcast_vars)

    def cache(self):
        self.is_cached = True
@@ -52,9 +58,10 @@ class RDD(object):
    def _pipe(self, functions, command):
        class_manifest = self._jrdd.classManifest()
-        pipe_command = RDD._get_pipe_command(command, functions)
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, command, functions)
        python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, class_manifest)
+            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
        return python_rdd.asJavaRDD()

    def distinct(self):
@@ -249,10 +256,12 @@ class RDD(object):
    def shuffle(self, numSplits, hashFunc=hash):
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism
-        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
        class_manifest = self._jrdd.classManifest()
        python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, class_manifest)
+            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
+            class_manifest)
        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
        jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
    @property
    def _jrdd(self):
        if not self._jrdd_val:
-            funcs = [self.func]
-            pipe_command = RDD._get_pipe_command("pipeline", funcs)
+            (pipe_command, broadcast_vars) = \
+                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
            class_manifest = self._prev_jrdd.classManifest()
            python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                class_manifest)
+                broadcast_vars, class_manifest)
            self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val
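
The rdd.py changes implement a "collect what got pickled" pattern: serializing a stage's functions may pickle Broadcast objects as a side effect, _get_pipe_command gathers the corresponding Java handles from ctx._pickled_broadcast_vars, and then clears the set so the next pipe command only ships the broadcasts it actually references. The sketch below isolates that bookkeeping with a stand-in class; FakeBroadcast and the variable names are illustrative, not part of the patch:

# Isolated sketch of the per-stage broadcast bookkeeping in _get_pipe_command.
# FakeBroadcast stands in for pyspark.broadcast.Broadcast.
import cPickle

class FakeBroadcast(object):
    def __init__(self, uuid, pickle_registry):
        self.uuid = uuid
        self._pickle_registry = pickle_registry
    def __reduce__(self):
        self._pickle_registry.add(self)   # side effect: remember we were shipped
        return (str, (self.uuid,))        # stand-in for the _from_uuid lookup

pickled_broadcast_vars = set()            # stands in for ctx._pickled_broadcast_vars
b = FakeBroadcast("uuid-1234", pickled_broadcast_vars)

cPickle.dumps({"closure_state": b})       # pickling a function's state registers b
to_ship = [x.uuid for x in pickled_broadcast_vars]
pickled_broadcast_vars.clear()            # the next stage starts from a clean slate
print to_ship                             # ['uuid-1234']
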
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 76b09918e7..7402897ac8 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -5,6 +5,7 @@ import sys
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import dumps, loads, PickleSerializer
import cPickle
@@ -63,6 +64,11 @@ def do_shuffle_map_step():
def main():
+    num_broadcast_variables = int(sys.stdin.readline().strip())
+    for _ in range(num_broadcast_variables):
+        uuid = sys.stdin.read(36)
+        value = loads(sys.stdin)
+        _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
    command = sys.stdin.readline().strip()
    if command == "pipeline":
        do_pipeline()