From 1381fc72f7a34f690a98ab72cec8ffb61e0e564d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jan 2014 19:50:26 -0800 Subject: Switch from MUTF8 to UTF8 in PySpark serializers. This fixes SPARK-1043, a bug introduced in 0.9.0 where PySpark couldn't serialize strings > 64kB. This fix was written by @tyro89 and @bouk in #512. This commit squashes and rebases their pull request in order to fix some merge conflicts. --- .../org/apache/spark/api/python/PythonRDD.scala | 18 +++++++---- .../apache/spark/api/python/PythonRDDSuite.scala | 35 ++++++++++++++++++++++ python/pyspark/context.py | 4 +-- python/pyspark/serializers.py | 6 ++-- python/pyspark/worker.py | 8 ++--- 5 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 70516bde8b..9cbd26b607 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -64,7 +64,7 @@ private[spark] class PythonRDD[T: ClassTag]( // Partition index dataOut.writeInt(split.index) // sparkFilesDir - dataOut.writeUTF(SparkFiles.getRootDirectory) + PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -74,7 +74,9 @@ private[spark] class PythonRDD[T: ClassTag]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - pythonIncludes.foreach(dataOut.writeUTF) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } dataOut.flush() // Serialized command: dataOut.writeInt(command.length) @@ -228,7 +230,7 @@ private[spark] object PythonRDD { } case string: String => newIter.asInstanceOf[Iterator[String]].foreach { str => - dataOut.writeUTF(str) + writeUTF(str, dataOut) } case pair: Tuple2[_, _] => pair._1 match { @@ -241,8 +243,8 @@ private[spark] object PythonRDD { } case stringPair: String => newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - dataOut.writeUTF(pair._1) - dataOut.writeUTF(pair._2) + writeUTF(pair._1, dataOut) + writeUTF(pair._2, dataOut) } case other => throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) @@ -253,6 +255,12 @@ private[spark] object PythonRDD { } } + def writeUTF(str: String, dataOut: DataOutputStream) { + val bytes = str.getBytes("UTF-8") + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ writeToFile(items.asScala, filename) diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala new file mode 100644 index 0000000000..1bebfe5ec8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.api.python.PythonRDD + +import java.io.{ByteArrayOutputStream, DataOutputStream} + +class PythonRDDSuite extends FunSuite { + + test("Writing large strings to the worker") { + val input: List[String] = List("a"*100000) + val buffer = new DataOutputStream(new ByteArrayOutputStream) + PythonRDD.writeIteratorToStream(input.iterator, buffer) + } + +} + diff --git a/python/pyspark/context.py b/python/pyspark/context.py index f955aad7a4..f318b5d9a7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -27,7 +27,7 @@ from pyspark.broadcast import Broadcast from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -234,7 +234,7 @@ class SparkContext(object): """ minSplits = minSplits or min(self.defaultParallelism, 2) return RDD(self._jsc.textFile(name, minSplits), self, - MUTF8Deserializer()) + UTF8Deserializer()) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2a500ab919..8c6ad79059 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -261,13 +261,13 @@ class MarshalSerializer(FramedSerializer): loads = marshal.loads -class MUTF8Deserializer(Serializer): +class UTF8Deserializer(Serializer): """ - Deserializes streams written by Java's DataOutputStream.writeUTF(). + Deserializes streams written by getBytes. """ def loads(self, stream): - length = struct.unpack('>H', stream.read(2))[0] + length = read_int(stream) return stream.read(length).decode('utf8') def load_stream(self, stream): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d77981f61f..4be4063dcf 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,11 +30,11 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer pickleSer = PickleSerializer() -mutf8_deserializer = MUTF8Deserializer() +utf8_deserializer = UTF8Deserializer() def report_times(outfile, boot, init, finish): @@ -51,7 +51,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = mutf8_deserializer.loads(infile) + spark_files_dir = utf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -66,7 +66,7 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - filename = mutf8_deserializer.loads(infile) + filename = utf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) command = pickleSer._read_with_length(infile) -- cgit v1.2.3