aboutsummaryrefslogtreecommitdiff
path: root/pyspark
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@eecs.berkeley.edu>2012-10-19 10:24:49 -0700
committerJosh Rosen <joshrosen@eecs.berkeley.edu>2012-10-19 10:24:49 -0700
commit52989c8a2c8c10d7f5610c033f6782e58fd3abc2 (patch)
tree507990a04e38087ab7a1229859a47c0b602c316c /pyspark
parente21eb6e00ddb77f40ecca9144b7405a293b97573 (diff)
downloadspark-52989c8a2c8c10d7f5610c033f6782e58fd3abc2.tar.gz
spark-52989c8a2c8c10d7f5610c033f6782e58fd3abc2.tar.bz2
spark-52989c8a2c8c10d7f5610c033f6782e58fd3abc2.zip
Update Python API for v0.6.0 compatibility.
Diffstat (limited to 'pyspark')
-rw-r--r--pyspark/pyspark/broadcast.py18
-rw-r--r--pyspark/pyspark/context.py2
-rw-r--r--pyspark/pyspark/java_gateway.py3
-rw-r--r--pyspark/pyspark/serializers.py18
-rw-r--r--pyspark/pyspark/worker.py8
5 files changed, 30 insertions, 19 deletions
diff --git a/pyspark/pyspark/broadcast.py b/pyspark/pyspark/broadcast.py
index 1ea17d59af..4cff02b36d 100644
--- a/pyspark/pyspark/broadcast.py
+++ b/pyspark/pyspark/broadcast.py
@@ -6,7 +6,7 @@
[1, 2, 3, 4, 5]
>>> from pyspark.broadcast import _broadcastRegistry
->>> _broadcastRegistry[b.uuid] = b
+>>> _broadcastRegistry[b.bid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]
@@ -14,27 +14,27 @@
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
"""
-# Holds broadcasted data received from Java, keyed by UUID.
+# Holds broadcasted data received from Java, keyed by its id.
_broadcastRegistry = {}
-def _from_uuid(uuid):
+def _from_id(bid):
from pyspark.broadcast import _broadcastRegistry
- if uuid not in _broadcastRegistry:
- raise Exception("Broadcast variable '%s' not loaded!" % uuid)
- return _broadcastRegistry[uuid]
+ if bid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % bid)
+ return _broadcastRegistry[bid]
class Broadcast(object):
- def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
+ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
self.value = value
- self.uuid = uuid
+ self.bid = bid
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
def __reduce__(self):
self._pickle_registry.add(self)
- return (_from_uuid, (self.uuid, ))
+ return (_from_id, (self.bid, ))
def _test():
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 04932c93f2..3f4db26644 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -66,5 +66,5 @@ class SparkContext(object):
def broadcast(self, value):
jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
- return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+ return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
diff --git a/pyspark/pyspark/java_gateway.py b/pyspark/pyspark/java_gateway.py
index bcb405ba72..3726bcbf17 100644
--- a/pyspark/pyspark/java_gateway.py
+++ b/pyspark/pyspark/java_gateway.py
@@ -7,7 +7,8 @@ SPARK_HOME = os.environ["SPARK_HOME"]
assembly_jar = glob.glob(os.path.join(SPARK_HOME, "core/target") + \
- "/spark-core-assembly-*-SNAPSHOT.jar")[0]
+ "/spark-core-assembly-*.jar")[0]
+ # TODO: what if multiple assembly jars are found?
def launch_gateway():
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index faa1e683c7..21ef8b106c 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -9,16 +9,26 @@ def dump_pickle(obj):
load_pickle = cPickle.loads
+def read_long(stream):
+ length = stream.read(8)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+ length = stream.read(4)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!i", length)[0]
+
def write_with_length(obj, stream):
stream.write(struct.pack("!i", len(obj)))
stream.write(obj)
def read_with_length(stream):
- length = stream.read(4)
- if length == "":
- raise EOFError
- length = struct.unpack("!i", length)[0]
+ length = read_int(stream)
obj = stream.read(length)
if obj == "":
raise EOFError
diff --git a/pyspark/pyspark/worker.py b/pyspark/pyspark/worker.py
index a9ed71892f..62824a1c9b 100644
--- a/pyspark/pyspark/worker.py
+++ b/pyspark/pyspark/worker.py
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import write_with_length, read_with_length, \
- dump_pickle, load_pickle
+ read_long, read_int, dump_pickle, load_pickle
# Redirect stdout to stderr so that users must return values from functions.
@@ -29,11 +29,11 @@ def read_input():
def main():
- num_broadcast_variables = int(sys.stdin.readline().strip())
+ num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
- uuid = sys.stdin.read(36)
+ bid = read_long(sys.stdin)
value = read_with_length(sys.stdin)
- _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+ _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
func = load_obj()
bypassSerializer = load_obj()
if bypassSerializer: