path: root/python/pyspark/worker.py
author    Josh Rosen <joshrosen@apache.org>  2013-11-05 17:52:39 -0800
committer Josh Rosen <joshrosen@apache.org>  2013-11-10 16:45:38 -0800
commit    cbb7f04aef2220ece93dea9f3fa98b5db5f270d6 (patch)
tree      5feaed6b6064b81272fcb74b48ee2579e32de4e6 /python/pyspark/worker.py
parent    7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e (diff)
Add custom serializer support to PySpark.
For now, this only adds MarshalSerializer, but it lays the groundwork for supporting other custom serializers. Many of these mechanisms can also be used to support deserialization of different data formats sent by Java, such as data encoded with MsgPack. This also fixes a bug in SparkContext.union().
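As a quick illustration of what the hook enables, here is a minimal sketch of switching a job to MarshalSerializer. It is not taken from this patch: the 'local' master and app name are placeholders, and it presumes the serializer keyword that SparkContext gains in this change.

    from pyspark import SparkContext
    from pyspark.serializers import MarshalSerializer

    # Assumption: SparkContext now accepts a custom serializer. MarshalSerializer
    # is faster than the default pickle-based serializer but only handles
    # built-in Python types, so it suits simple numeric/string workloads.
    sc = SparkContext('local', 'serializer-demo', serializer=MarshalSerializer())
    print(sc.parallelize(range(10)).map(lambda x: x * 2).collect())
    sc.stop()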
Diffstat (limited to 'python/pyspark/worker.py')
-rw-r--r--  python/pyspark/worker.py | 41
1 file changed, 19 insertions, 22 deletions
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4e64557fc4..5b16d5db7e 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
- SpecialLengths, read_mutf8, read_pairs_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+ write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+
+
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
def load_obj(infile):
- return load_pickle(standard_b64decode(infile.readline().strip()))
+ decoded = standard_b64decode(infile.readline().strip())
+ return pickleSer._loads(decoded)
def report_times(outfile, boot, init, finish):
@@ -53,7 +57,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = read_mutf8(infile)
+ spark_files_dir = mutf8_deserializer._loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -61,31 +65,24 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
+ filename = mutf8_deserializer._loads(infile)
+ sys.path.append(os.path.join(spark_files_dir, filename))
- # now load function
+ # Load this stage's function and serializer:
func = load_obj(infile)
- bypassSerializer = load_obj(infile)
- stageInputIsPairs = load_obj(infile)
- if bypassSerializer:
- dumps = lambda x: x
- else:
- dumps = dump_pickle
+ deserializer = load_obj(infile)
+ serializer = load_obj(infile)
init_time = time.time()
- if stageInputIsPairs:
- iterator = read_pairs_from_pickle_file(infile)
- else:
- iterator = read_from_pickle_file(infile)
try:
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), outfile)
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
@@ -96,7 +93,7 @@ def main(infile, outfile):
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
for (aid, accum) in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), outfile)
+ pickleSer._write_with_length((aid, accum._value), outfile)
if __name__ == '__main__':
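For context, the stream-based contract the rewritten loop relies on boils down to two methods, dump_stream and load_stream. The sketch below is illustrative only; the class name and the exact framing are assumptions for this example, not the pyspark.serializers implementation.

    import pickle
    import struct

    class LengthFramedPickleSerializer(object):
        """Illustrative stand-in for the framed serializers the worker uses."""

        def dump_stream(self, iterator, stream):
            # Write each object as a 4-byte big-endian length followed by its pickle.
            for obj in iterator:
                data = pickle.dumps(obj, 2)
                stream.write(struct.pack("!i", len(data)))
                stream.write(data)

        def load_stream(self, stream):
            # Yield objects back until the stream runs out.
            while True:
                header = stream.read(4)
                if len(header) < 4:
                    return
                length = struct.unpack("!i", header)[0]
                yield pickle.loads(stream.read(length))

With both ends agreeing on a contract of this shape, the worker no longer needs the bypassSerializer and stageInputIsPairs special cases removed above; it simply pipes deserializer.load_stream(infile) through func and into serializer.dump_stream(..., outfile).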