aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala106
-rw-r--r--python/pyspark/context.py10
-rw-r--r--python/pyspark/rdd.py11
-rw-r--r--python/pyspark/serializers.py18
-rw-r--r--python/pyspark/worker.py14
5 files changed, 78 insertions, 81 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d5913ec60..eb0b0db0cc 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest](
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ dataOut.writeUTF(SparkFiles.getRootDirectory)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- for (f <- pythonIncludes) {
- PythonRDD.writeAsPickle(f, dataOut)
- }
+ pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
// Serialized user code
for (elem <- command) {
@@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest](
printOut.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
printOut.flush()
@@ -205,60 +203,7 @@ private object SpecialLengths {
private[spark] object PythonRDD {
- /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
- def stripPickle(arr: Array[Byte]) : Array[Byte] = {
- arr.slice(2, arr.length - 1)
- }
-
- /**
- * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
- * The data format is a 32-bit integer representing the pickled object's length (in bytes),
- * followed by the pickled data.
- *
- * Pickle module:
- *
- * http://docs.python.org/2/library/pickle.html
- *
- * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
- *
- * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
- * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
- *
- * @param elem the object to write
- * @param dOut a data output stream
- */
- def writeAsPickle(elem: Any, dOut: DataOutputStream) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
- val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
- val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t._1))
- dOut.write(PythonRDD.stripPickle(t._2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new SparkException("Unexpected RDD type")
- }
- }
-
- def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -276,15 +221,46 @@ private[spark] object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
+ val s = elem.getBytes("UTF-8")
+ val length = 2 + 1 + 4 + s.length + 1
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(Pickle.BINUNICODE)
+ dOut.writeInt(Integer.reverseBytes(s.length))
+ dOut.write(s)
+ dOut.writeByte(Pickle.STOP)
+ }
+
+ def writeToStream(elem: Any, dataOut: DataOutputStream) {
+ elem match {
+ case bytes: Array[Byte] =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ case pair: (Array[Byte], Array[Byte]) =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ case str: String =>
+ // Until we've implemented full custom serializer support, we need to return
+ // strings as Pickles to properly support union() and cartesian():
+ writeStringAsPickle(str, dataOut)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
+ }
+ }
+
+ def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
- writeIteratorToPickleFile(items.asScala, filename)
+ writeToFile(items.asScala, filename)
}
- def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
- writeAsPickle(item, file)
+ writeToStream(item, file)
}
file.close()
}
@@ -300,10 +276,6 @@ private object Pickle {
val TWO: Byte = 0x02.toByte
val BINUNICODE: Byte = 'X'
val STOP: Byte = '.'
- val TUPLE2: Byte = 0x86.toByte
- val EMPTY_LIST: Byte = ']'
- val MARK: Byte = '('
- val APPENDS: Byte = 'e'
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888..0fec1a6bf6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -42,7 +42,7 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
+ _writeToFile = None
_takePartition = None
_next_accum_id = 0
_active_spark_context = None
@@ -125,8 +125,8 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._writeToFile = \
+ SparkContext._jvm.PythonRDD.writeToFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
@@ -190,8 +190,8 @@ class SparkContext(object):
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8bee..d3c4d13a1e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -54,6 +54,7 @@ class RDD(object):
self.is_checkpointed = False
self.ctx = ctx
self._partitionFunc = None
+ self._stage_input_is_pairs = False
@property
def context(self):
@@ -344,6 +345,7 @@ class RDD(object):
yield pair
else:
yield pair
+ java_cartesian._stage_input_is_pairs = True
return java_cartesian.flatMap(unpack_batches)
def groupBy(self, f, numPartitions=None):
@@ -391,8 +393,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -400,7 +402,7 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in read_from_pickle_file(tempFile):
@@ -941,6 +943,7 @@ class PipelinedRDD(RDD):
self.func = func
self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd
+ self._stage_input_is_pairs = prev._stage_input_is_pairs
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
@@ -959,7 +962,7 @@ class PipelinedRDD(RDD):
def batched_func(split, iterator):
return batched(oldfunc(split, iterator), batchSize)
func = batched_func
- cmds = [func, self._bypass_serializer]
+ cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fbc280fd37..fd02e1ee8f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -93,6 +93,14 @@ def write_with_length(obj, stream):
stream.write(obj)
+def read_mutf8(stream):
+ """
+ Read a string written with Java's DataOutputStream.writeUTF() method.
+ """
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
+
+
def read_with_length(stream):
length = read_int(stream)
obj = stream.read(length)
@@ -112,3 +120,13 @@ def read_from_pickle_file(stream):
yield obj
except EOFError:
return
+
+
+def read_pairs_from_pickle_file(stream):
+ try:
+ while True:
+ a = load_pickle(read_with_length(stream))
+ b = load_pickle(read_with_length(stream))
+ yield (a, b)
+ except EOFError:
+ return \ No newline at end of file
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7696df9d1c..4e64557fc4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
- SpecialLengths
+ read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
+ SpecialLengths, read_mutf8, read_pairs_from_pickle_file
def load_obj(infile):
@@ -53,7 +53,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = load_pickle(read_with_length(infile))
+ spark_files_dir = read_mutf8(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -68,17 +68,21 @@ def main(infile, outfile):
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+ sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
# now load function
func = load_obj(infile)
bypassSerializer = load_obj(infile)
+ stageInputIsPairs = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
init_time = time.time()
- iterator = read_from_pickle_file(infile)
+ if stageInputIsPairs:
+ iterator = read_pairs_from_pickle_file(infile)
+ else:
+ iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
write_with_length(dumps(obj), outfile)