From f79a1e4d2a8643157136de69b8d7de84f0034712 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 25 Aug 2012 13:59:01 -0700
Subject: Add broadcast variables to Python API.

---
 .../main/scala/spark/api/python/PythonRDD.scala | 43 ++++++++++++--------
 pyspark/pyspark/broadcast.py                    | 46 ++++++++++++++++++++++
 pyspark/pyspark/context.py                      | 17 ++++++--
 pyspark/pyspark/rdd.py                          | 27 ++++++++-----
 pyspark/pyspark/worker.py                       |  6 +++
 5 files changed, 110 insertions(+), 29 deletions(-)
 create mode 100644 pyspark/pyspark/broadcast.py

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 93847e2f14..5163812df4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -7,14 +7,13 @@ import scala.collection.JavaConversions._
 import scala.io.Source
 import spark._
 import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import scala.{collection, Some}
-import collection.parallel.mutable
+import broadcast.Broadcast
 import scala.collection
-import scala.Some
 
 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
-    command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = {
+    command: Seq[String], parent: RDD[T], pythonExec: String,
+    broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = {
     val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
 
     val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@@ -42,11 +41,18 @@ trait PythonRDDBase {
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
+        val dOut = new DataOutputStream(proc.getOutputStream)
+        out.println(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          out.print(broadcast.uuid.toString)
+          dOut.writeInt(broadcast.value.length)
+          dOut.write(broadcast.value)
+          dOut.flush()
+        }
         for (elem <- command) {
           out.println(elem)
         }
         out.flush()
-        val dOut = new DataOutputStream(proc.getOutputStream)
         for (elem <- parent.iterator(split)) {
           if (elem.isInstanceOf[Array[Byte]]) {
             val arr = elem.asInstanceOf[Array[Byte]]
@@ -121,16 +127,17 @@ trait PythonRDDBase {
 
 class PythonRDD[T: ClassManifest](
   parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-  preservePartitoning: Boolean, pythonExec: String)
+  preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[Array[Byte]](parent.context) with PythonRDDBase {
 
-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+    pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
 
   override def splits = parent.splits
 
@@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest](
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
   override def compute(split: Split): Iterator[Array[Byte]] =
-    compute(split, envVars, command, parent, pythonExec)
+    compute(split, envVars, command, parent, pythonExec, broadcastVars)
 
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
 class PythonPairRDD[T: ClassManifest] (
   parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-  preservePartitoning: Boolean, pythonExec: String)
+  preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {
 
-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+    pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
+    broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
 
   override def splits = parent.splits
 
@@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] (
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
   override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
-    compute(split, envVars, command, parent, pythonExec).grouped(2).map {
+    compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
       case Seq(a, b) => (a, b)
       case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
     }
diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py
new file mode 100644
index 0000000000..1ea17d59af
--- /dev/null
+++ b/pyspark/pyspark/broadcast.py
@@ -0,0 +1,46 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.uuid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+"""
+# Holds broadcasted data received from Java, keyed by UUID.
+_broadcastRegistry = {}
+
+
+def _from_uuid(uuid):
+    from pyspark.broadcast import _broadcastRegistry
+    if uuid not in _broadcastRegistry:
+        raise Exception("Broadcast variable '%s' not loaded!" % uuid)
+    return _broadcastRegistry[uuid]
+
+
+class Broadcast(object):
+    def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+        self.value = value
+        self.uuid = uuid
+        self._jbroadcast = java_broadcast
+        self._pickle_registry = pickle_registry
+
+    def __reduce__(self):
+        self._pickle_registry.add(self)
+        return (_from_uuid, (self.uuid, ))
+
+
+def _test():
+    import doctest
+    doctest.testmod()
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index ac7e4057e9..6f87206665 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -2,6 +2,7 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile
 
+from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, dumps
 from pyspark.rdd import RDD
@@ -24,6 +25,11 @@ class SparkContext(object):
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = pythonExec
+        # Broadcast's __reduce__ method stores Broadcast instances here.
+        # This allows other code to determine which Broadcast instances have
+        # been pickled, so it can determine which Java broadcast objects to
+        # send.
+        self._pickled_broadcast_vars = set()
 
     def __del__(self):
         if self._jsc:
@@ -52,7 +58,12 @@ class SparkContext(object):
         jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
-    def textFile(self, name, numSlices=None):
-        numSlices = numSlices or self.defaultParallelism
-        jrdd = self._jsc.textFile(name, numSlices)
+    def textFile(self, name, minSplits=None):
+        minSplits = minSplits or min(self.defaultParallelism, 2)
+        jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
+
+    def broadcast(self, value):
+        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+                         self._pickled_broadcast_vars)
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index af7703fdfc..4459095391 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 
+from py4j.java_collections import ListConverter
+
 
 class RDD(object):
 
@@ -15,11 +17,15 @@ class RDD(object):
         self.ctx = ctx
 
     @classmethod
-    def _get_pipe_command(cls, command, functions):
+    def _get_pipe_command(cls, ctx, command, functions):
         worker_args = [command]
         for f in functions:
             worker_args.append(b64enc(cloudpickle.dumps(f)))
-        return " ".join(worker_args)
+        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
+        broadcast_vars = ListConverter().convert(broadcast_vars,
+                                                 ctx.gateway._gateway_client)
+        ctx._pickled_broadcast_vars.clear()
+        return (" ".join(worker_args), broadcast_vars)
 
     def cache(self):
         self.is_cached = True
@@ -52,9 +58,10 @@ class RDD(object):
 
     def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
-        pipe_command = RDD._get_pipe_command(command, functions)
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, command, functions)
         python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, class_manifest)
+            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
         return python_rdd.asJavaRDD()
 
     def distinct(self):
@@ -249,10 +256,12 @@ class RDD(object):
     def shuffle(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
         class_manifest = self._jrdd.classManifest()
         python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, class_manifest)
+            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
+            class_manifest)
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
         jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
     @property
     def _jrdd(self):
         if not self._jrdd_val:
-            funcs = [self.func]
-            pipe_command = RDD._get_pipe_command("pipeline", funcs)
+            (pipe_command, broadcast_vars) = \
+                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
             class_manifest = self._prev_jrdd.classManifest()
             python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                 pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                class_manifest)
+                broadcast_vars, class_manifest)
             self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index 76b09918e7..7402897ac8 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -5,6 +5,7 @@ import sys
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import dumps, loads, PickleSerializer
 import cPickle
@@ -63,6 +64,11 @@ def do_shuffle_map_step():
 
 
 def main():
+    num_broadcast_variables = int(sys.stdin.readline().strip())
+    for _ in range(num_broadcast_variables):
+        uuid = sys.stdin.read(36)
+        value = loads(sys.stdin)
+        _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
     command = sys.stdin.readline().strip()
     if command == "pipeline":
         do_pipeline()
-- 
cgit v1.2.3
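
Illustrative usage (not part of the patch): the sketch below mirrors the doctest added in pyspark/pyspark/broadcast.py to show how the new API is driven end to end. The 'local' master, the app name, and the sample data are arbitrary choices for the example.

    from pyspark.context import SparkContext

    sc = SparkContext('local', 'broadcast-example')

    # SparkContext.broadcast() pickles the value, hands the bytes to the JVM's
    # broadcast mechanism, and returns a Broadcast handle wrapping both.
    b = sc.broadcast([1, 2, 3, 4, 5])
    assert b.value == [1, 2, 3, 4, 5]

    # Referencing `b` inside a closure pickles it via Broadcast.__reduce__,
    # which registers it in ctx._pickled_broadcast_vars so that the matching
    # Java broadcast objects are passed to the PythonRDD being constructed.
    result = sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
    assert result == [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]

On the worker side, worker.py first reads the number of broadcast variables, then for each one a 36-character UUID followed by a length-prefixed pickled value (written by the Scala stdin-writer thread), and fills _broadcastRegistry before reading the command and its serialized functions.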