From ffa5bedf46fbc89ad5c5658f3b423dfff49b70f0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 Nov 2013 12:58:28 -0800 Subject: Send PySpark commands as bytes insetad of strings. --- .../org/apache/spark/api/python/PythonRDD.scala | 24 ++++------------------ python/pyspark/rdd.py | 12 +++++------ python/pyspark/serializers.py | 5 +++++ python/pyspark/worker.py | 12 ++--------- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ef9bf4db9b..132e4fb0d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir @@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest]( dataOut.writeInt(pythonIncludes.length) pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6691c30519..062f44f81e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -27,9 +27,8 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread -from pyspark import cloudpickle from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, pack_long + BatchedSerializer, CloudPickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -970,8 +969,8 @@ class PipelinedRDD(RDD): serializer = NoOpSerializer() else: serializer = self.ctx.serializer - cmds = [self.func, self._prev_jrdd_deserializer, serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer()._dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) @@ -982,8 +981,9 @@ class PipelinedRDD(RDD): includes = ListConverter().convert(self.ctx._python_includes, self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, class_manifest) + bytearray(pickled_command), env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4fb444443f..b23804b33c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -64,6 +64,7 @@ import cPickle from itertools import chain, izip, product import marshal import struct +from pyspark import cloudpickle __all__ = ["PickleSerializer", "MarshalSerializer"] @@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer): def _dumps(self, obj): return cPickle.dumps(obj, 2) _loads = cPickle.loads +class CloudPickleSerializer(PickleSerializer): + + def _dumps(self, obj): return cloudpickle.dumps(obj, 2) + class MarshalSerializer(FramedSerializer): """ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5b16d5db7e..2751f1239e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,7 +23,6 @@ import sys import time import socket import traceback -from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.accumulators import _accumulatorRegistry @@ -38,11 +37,6 @@ pickleSer = PickleSerializer() mutf8_deserializer = MUTF8Deserializer() -def load_obj(infile): - decoded = standard_b64decode(infile.readline().strip()) - return pickleSer._loads(decoded) - - def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) @@ -75,10 +69,8 @@ def main(infile, outfile): filename = mutf8_deserializer._loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) - # Load this stage's function and serializer: - func = load_obj(infile) - deserializer = load_obj(infile) - serializer = load_obj(infile) + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command init_time = time.time() try: iterator = deserializer.load_stream(infile) -- cgit v1.2.3