author     Josh Rosen <joshrosen@apache.org>    2014-01-28 21:30:20 -0800
committer  Josh Rosen <joshrosen@apache.org>    2014-01-28 21:30:20 -0800
commit     f8c742ce274fbae2a9e616d4c97469b6a22069bb (patch)
tree       8ae129c4b291b4b5589a77b919f508c4535fbf2c
parent     84670f2715392859624df290c1b52eb4ed4a9cb1 (diff)
parent     1381fc72f7a34f690a98ab72cec8ffb61e0e564d (diff)
Merge pull request #523 from JoshRosen/SPARK-1043
Switch from MUTF8 to UTF8 in PySpark serializers.

This fixes SPARK-1043, a bug introduced in 0.9.0 where PySpark couldn't serialize strings > 64kB.

This fix was written by @tyro89 and @bouk in #512. This commit squashes and rebases their pull request in order to fix some merge conflicts.
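For context, the 64kB limit comes from Java's DataOutputStream.writeUTF, which prefixes the payload with an unsigned 2-byte length and therefore rejects any string whose encoded form exceeds 65535 bytes; the patch replaces that with an explicit 4-byte length followed by the raw UTF-8 bytes, which the Python side reads back with read_int() and decode('utf8'). Below is a minimal standalone sketch of the old failure mode and the new framing; the Utf8FramingDemo object name is illustrative and not part of the patch.

    import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException}

    object Utf8FramingDemo {
      def main(args: Array[String]): Unit = {
        val big = "a" * 100000
        val out = new DataOutputStream(new ByteArrayOutputStream())

        // Old scheme: writeUTF stores an unsigned 2-byte length, so strings whose
        // UTF-8 encoding exceeds 65535 bytes are rejected with an exception.
        try {
          out.writeUTF(big)
        } catch {
          case e: UTFDataFormatException =>
            println("writeUTF failed: " + e.getMessage)
        }

        // New scheme: a 4-byte big-endian length followed by the raw UTF-8 bytes,
        // which the Python deserializer reads back with read_int() + decode('utf8').
        val bytes = big.getBytes("UTF-8")
        out.writeInt(bytes.length)
        out.write(bytes)
      }
    }

Since the length is now a 4-byte int, the framing accommodates strings far larger than 64kB, which is what the new PythonRDDSuite test exercises with a 100,000-character string.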
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala      | 18
-rw-r--r--  core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala | 35
-rw-r--r--  python/pyspark/context.py                                             |  4
-rw-r--r--  python/pyspark/serializers.py                                         |  6
-rw-r--r--  python/pyspark/worker.py                                              |  8
5 files changed, 57 insertions(+), 14 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 70516bde8b..9cbd26b607 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -64,7 +64,7 @@ private[spark] class PythonRDD[T: ClassTag](
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- dataOut.writeUTF(SparkFiles.getRootDirectory)
+ PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -74,7 +74,9 @@ private[spark] class PythonRDD[T: ClassTag](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- pythonIncludes.foreach(dataOut.writeUTF)
+ for (include <- pythonIncludes) {
+ PythonRDD.writeUTF(include, dataOut)
+ }
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
@@ -228,7 +230,7 @@ private[spark] object PythonRDD {
}
case string: String =>
newIter.asInstanceOf[Iterator[String]].foreach { str =>
- dataOut.writeUTF(str)
+ writeUTF(str, dataOut)
}
case pair: Tuple2[_, _] =>
pair._1 match {
@@ -241,8 +243,8 @@ private[spark] object PythonRDD {
}
case stringPair: String =>
newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- dataOut.writeUTF(pair._1)
- dataOut.writeUTF(pair._2)
+ writeUTF(pair._1, dataOut)
+ writeUTF(pair._2, dataOut)
}
case other =>
throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
@@ -253,6 +255,12 @@ private[spark] object PythonRDD {
}
}
+ def writeUTF(str: String, dataOut: DataOutputStream) {
+ val bytes = str.getBytes("UTF-8")
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+
def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeToFile(items.asScala, filename)
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
new file mode 100644
index 0000000000..1bebfe5ec8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.api.python.PythonRDD
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+
+class PythonRDDSuite extends FunSuite {
+
+ test("Writing large strings to the worker") {
+ val input: List[String] = List("a"*100000)
+ val buffer = new DataOutputStream(new ByteArrayOutputStream)
+ PythonRDD.writeIteratorToStream(input.iterator, buffer)
+ }
+
+}
+
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index f955aad7a4..f318b5d9a7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -27,7 +27,7 @@ from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -234,7 +234,7 @@ class SparkContext(object):
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
return RDD(self._jsc.textFile(name, minSplits), self,
- MUTF8Deserializer())
+ UTF8Deserializer())
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a500ab919..8c6ad79059 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -261,13 +261,13 @@ class MarshalSerializer(FramedSerializer):
loads = marshal.loads
-class MUTF8Deserializer(Serializer):
+class UTF8Deserializer(Serializer):
"""
- Deserializes streams written by Java's DataOutputStream.writeUTF().
+ Deserializes streams written by String.getBytes().
"""
def loads(self, stream):
- length = struct.unpack('>H', stream.read(2))[0]
+ length = read_int(stream)
return stream.read(length).decode('utf8')
def load_stream(self, stream):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d77981f61f..4be4063dcf 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,11 +30,11 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
pickleSer = PickleSerializer()
-mutf8_deserializer = MUTF8Deserializer()
+utf8_deserializer = UTF8Deserializer()
def report_times(outfile, boot, init, finish):
@@ -51,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = mutf8_deserializer.loads(infile)
+ spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -66,7 +66,7 @@ def main(infile, outfile):
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- filename = mutf8_deserializer.loads(infile)
+ filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
command = pickleSer._read_with_length(infile)