 core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 24
 python/pyspark/rdd.py                                           | 12
 python/pyspark/serializers.py                                   |  5
 python/pyspark/worker.py                                        | 12
 4 files changed, 17 insertions(+), 36 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ef9bf4db9b..132e4fb0d2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
- command: Seq[String],
+ command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
- accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
- broadcastVars, accumulator)
-
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
-
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
@@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
@@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest](
dataOut.writeInt(pythonIncludes.length)
pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
// Data values
for (elem <- parent.iterator(split, context)) {
PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
- printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
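
With this change the driver no longer pushes base64-encoded lines through a PrintWriter; it writes the pickled command as a single length-prefixed frame (writeInt followed by the raw bytes). A minimal sketch of the matching read on the Python side, assuming a blocking file-like stream; the helper name is illustrative, not the exact pyspark internals:

import struct

def read_framed_command(stream):
    # DataOutputStream.writeInt emits a 4-byte big-endian length header;
    # read it, then read exactly that many bytes of pickled command data.
    header = stream.read(4)
    if len(header) < 4:
        raise EOFError("stream closed before the length header was read")
    (length,) = struct.unpack("!i", header)
    payload = stream.read(length)
    if len(payload) < length:
        raise EOFError("stream closed before the full command was read")
    return payload
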
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6691c30519..062f44f81e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
- BatchedSerializer, pack_long
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -970,8 +969,8 @@ class PipelinedRDD(RDD):
serializer = NoOpSerializer()
else:
serializer = self.ctx.serializer
- cmds = [self.func, self._prev_jrdd_deserializer, serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer()._dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
@@ -982,8 +981,9 @@ class PipelinedRDD(RDD):
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
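
On the driver side the command tuple is now cloudpickled once and handed to the PythonRDD constructor as a bytearray, which Py4J converts to a JVM byte[]. A rough, self-contained illustration of that conversion, using cPickle and placeholder strings in place of the CloudPickleSerializer and the real (func, deserializer, serializer) objects from the diff:

import cPickle

# Stand-ins for (self.func, self._prev_jrdd_deserializer, serializer).
command = ("func-placeholder", "prev-deserializer-placeholder", "serializer-placeholder")
pickled_command = cPickle.dumps(command, 2)   # a Python 2 str of raw pickle bytes
java_friendly = bytearray(pickled_command)    # Py4J maps bytearray to byte[] on the JVM side
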
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4fb444443f..b23804b33c 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@ import cPickle
from itertools import chain, izip, product
import marshal
import struct
+from pyspark import cloudpickle
__all__ = ["PickleSerializer", "MarshalSerializer"]
@@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer):
def _dumps(self, obj): return cPickle.dumps(obj, 2)
_loads = cPickle.loads
+class CloudPickleSerializer(PickleSerializer):
+
+ def _dumps(self, obj): return cloudpickle.dumps(obj, 2)
+
class MarshalSerializer(FramedSerializer):
"""
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5b16d5db7e..2751f1239e 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,7 +23,6 @@ import sys
import time
import socket
import traceback
-from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
@@ -38,11 +37,6 @@ pickleSer = PickleSerializer()
mutf8_deserializer = MUTF8Deserializer()
-def load_obj(infile):
- decoded = standard_b64decode(infile.readline().strip())
- return pickleSer._loads(decoded)
-
-
def report_times(outfile, boot, init, finish):
write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
@@ -75,10 +69,8 @@ def main(infile, outfile):
filename = mutf8_deserializer._loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
- # Load this stage's function and serializer:
- func = load_obj(infile)
- deserializer = load_obj(infile)
- serializer = load_obj(infile)
+ command = pickleSer._read_with_length(infile)
+ (func, deserializer, serializer) = command
init_time = time.time()
try:
iterator = deserializer.load_stream(infile)
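
Putting the worker side together: instead of three base64-decoded objects read line by line, the worker now reads one length-prefixed pickle and unpacks the (func, deserializer, serializer) tuple. A hedged end-to-end sketch of that round trip, using plain cPickle and an in-memory buffer rather than the real worker socket and serializer classes:

import struct
import cPickle
from io import BytesIO

def write_with_length(data, stream):
    # Mirrors the Scala side: 4-byte big-endian length, then the raw bytes.
    stream.write(struct.pack("!i", len(data)))
    stream.write(data)

def read_with_length(stream):
    (length,) = struct.unpack("!i", stream.read(4))
    return stream.read(length)

# Driver side: pickle a stand-in command tuple and frame it.
buf = BytesIO()
command = ("func", "deserializer", "serializer")
write_with_length(cPickle.dumps(command, 2), buf)

# Worker side: read the frame back and unpack it.
buf.seek(0)
(func, deserializer, serializer) = cPickle.loads(read_with_length(buf))
assert (func, deserializer, serializer) == command
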