From 52989c8a2c8c10d7f5610c033f6782e58fd3abc2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 19 Oct 2012 10:24:49 -0700 Subject: Update Python API for v0.6.0 compatibility. --- core/src/main/scala/spark/api/python/PythonRDD.scala | 18 +++++++++++------- core/src/main/scala/spark/broadcast/Broadcast.scala | 2 +- pyspark/pyspark/broadcast.py | 18 +++++++++--------- pyspark/pyspark/context.py | 2 +- pyspark/pyspark/java_gateway.py | 3 ++- pyspark/pyspark/serializers.py | 18 ++++++++++++++---- pyspark/pyspark/worker.py | 8 ++++---- 7 files changed, 42 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 4d3bdb3963..528885fe5c 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -5,11 +5,15 @@ import java.io._ import scala.collection.Map import scala.collection.JavaConversions._ import scala.io.Source -import spark._ -import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} -import broadcast.Broadcast -import scala.collection -import java.nio.charset.Charset + +import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import spark.broadcast.Broadcast +import spark.SparkEnv +import spark.Split +import spark.RDD +import spark.OneToOneDependency +import spark.rdd.PipedRDD + trait PythonRDDBase { def compute[T](split: Split, envVars: Map[String, String], @@ -43,9 +47,9 @@ trait PythonRDDBase { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) - out.println(broadcastVars.length) + dOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { - out.print(broadcast.uuid.toString) + dOut.writeLong(broadcast.id) dOut.writeInt(broadcast.value.length) dOut.write(broadcast.value) dOut.flush() diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 6055bfd045..2ffe7f741d 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong import spark._ -abstract class Broadcast[T](id: Long) extends Serializable { +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py index 1ea17d59af..4cff02b36d 100644 --- a/pyspark/pyspark/broadcast.py +++ b/pyspark/pyspark/broadcast.py @@ -6,7 +6,7 @@ [1, 2, 3, 4, 5] >>> from pyspark.broadcast import _broadcastRegistry ->>> _broadcastRegistry[b.uuid] = b +>>> _broadcastRegistry[b.bid] = b >>> from cPickle import dumps, loads >>> loads(dumps(b)).value [1, 2, 3, 4, 5] @@ -14,27 +14,27 @@ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] """ -# Holds broadcasted data received from Java, keyed by UUID. +# Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} -def _from_uuid(uuid): +def _from_id(bid): from pyspark.broadcast import _broadcastRegistry - if uuid not in _broadcastRegistry: - raise Exception("Broadcast variable '%s' not loaded!" % uuid) - return _broadcastRegistry[uuid] + if bid not in _broadcastRegistry: + raise Exception("Broadcast variable '%s' not loaded!" % bid) + return _broadcastRegistry[bid] class Broadcast(object): - def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None): + def __init__(self, bid, value, java_broadcast=None, pickle_registry=None): self.value = value - self.uuid = uuid + self.bid = bid self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry def __reduce__(self): self._pickle_registry.add(self) - return (_from_uuid, (self.uuid, )) + return (_from_id, (self.bid, )) def _test(): diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py index 04932c93f2..3f4db26644 100644 --- a/pyspark/pyspark/context.py +++ b/pyspark/pyspark/context.py @@ -66,5 +66,5 @@ class SparkContext(object): def broadcast(self, value): jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) - return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast, + return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py index bcb405ba72..3726bcbf17 100644 --- a/pyspark/pyspark/java_gateway.py +++ b/pyspark/pyspark/java_gateway.py @@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"] assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \ - "/spark-core-assembly-*-SNAPSHOT.jar")[0] + "/spark-core-assembly-*.jar")[0] + # TODO: what if multiple assembly jars are found? def launch_gateway(): diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py index faa1e683c7..21ef8b106c 100644 --- a/pyspark/pyspark/serializers.py +++ b/pyspark/pyspark/serializers.py @@ -9,16 +9,26 @@ def dump_pickle(obj): load_pickle = cPickle.loads +def read_long(stream): + length = stream.read(8) + if length == "": + raise EOFError + return struct.unpack("!q", length)[0] + + +def read_int(stream): + length = stream.read(4) + if length == "": + raise EOFError + return struct.unpack("!i", length)[0] + def write_with_length(obj, stream): stream.write(struct.pack("!i", len(obj))) stream.write(obj) def read_with_length(stream): - length = stream.read(4) - if length == "": - raise EOFError - length = struct.unpack("!i", length)[0] + length = read_int(stream) obj = stream.read(length) if obj == "": raise EOFError diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py index a9ed71892f..62824a1c9b 100644 --- a/pyspark/pyspark/worker.py +++ b/pyspark/pyspark/worker.py @@ -8,7 +8,7 @@ from base64 import standard_b64decode from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.serializers import write_with_length, read_with_length, \ - dump_pickle, load_pickle + read_long, read_int, dump_pickle, load_pickle # Redirect stdout to stderr so that users must return values from functions. @@ -29,11 +29,11 @@ def read_input(): def main(): - num_broadcast_variables = int(sys.stdin.readline().strip()) + num_broadcast_variables = read_int(sys.stdin) for _ in range(num_broadcast_variables): - uuid = sys.stdin.read(36) + bid = read_long(sys.stdin) value = read_with_length(sys.stdin) - _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value)) + _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) func = load_obj() bypassSerializer = load_obj() if bypassSerializer: -- cgit v1.2.3