Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala      | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala    |  5
-rw-r--r--  core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala | 23
3 files changed, 41 insertions(+), 53 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 119e0459c5..b89effc16d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -316,6 +316,7 @@ private object SpecialLengths {
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val END_OF_STREAM = -4
+ val NULL = -5
}
private[spark] object PythonRDD extends Logging {
@@ -374,54 +375,25 @@ private[spark] object PythonRDD extends Logging {
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
- // The right way to implement this would be to use TypeTags to get the full
- // type of T. Since I don't want to introduce breaking changes throughout the
- // entire Spark API, I have to use this hacky approach:
- if (iter.hasNext) {
- val first = iter.next()
- val newIter = Seq(first).iterator ++ iter
- first match {
- case arr: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case string: String =>
- newIter.asInstanceOf[Iterator[String]].foreach { str =>
- writeUTF(str, dataOut)
- }
- case stream: PortableDataStream =>
- newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
- val bytes = stream.toArray()
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case (key: String, stream: PortableDataStream) =>
- newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
- case (key, stream) =>
- writeUTF(key, dataOut)
- val bytes = stream.toArray()
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case (key: String, value: String) =>
- newIter.asInstanceOf[Iterator[(String, String)]].foreach {
- case (key, value) =>
- writeUTF(key, dataOut)
- writeUTF(value, dataOut)
- }
- case (key: Array[Byte], value: Array[Byte]) =>
- newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
- case (key, value) =>
- dataOut.writeInt(key.length)
- dataOut.write(key)
- dataOut.writeInt(value.length)
- dataOut.write(value)
- }
- case other =>
- throw new SparkException("Unexpected element type " + first.getClass)
- }
+
+ def write(obj: Any): Unit = obj match {
+ case null =>
+ dataOut.writeInt(SpecialLengths.NULL)
+ case arr: Array[Byte] =>
+ dataOut.writeInt(arr.length)
+ dataOut.write(arr)
+ case str: String =>
+ writeUTF(str, dataOut)
+ case stream: PortableDataStream =>
+ write(stream.toArray())
+ case (key, value) =>
+ write(key)
+ write(value)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
}
+
+ iter.foreach(write)
}
/**
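
The refactored writeIteratorToStream above replaces the per-type branches with a small recursive write helper: byte arrays are framed as a 4-byte length plus payload, strings go through writeUTF (length plus UTF-8 bytes), a PortableDataStream is materialized to a byte array, a pair writes its key and then its value, and null is encoded as the new SpecialLengths.NULL (-5) marker instead of triggering a NullPointerException. The sketch below round-trips that framing; it is illustrative only, the object name is hypothetical, and it assumes the file is compiled inside the org.apache.spark.api.python package so the package-private PythonRDD and SpecialLengths objects are visible.

package org.apache.spark.api.python

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical round-trip check: write "a", null, "b" with the patched framing,
// then read the stream back, treating the -5 length marker as a null element.
object NullFramingRoundTrip {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    PythonRDD.writeIteratorToStream(Iterator("a", null, "b"), out)
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val decoded = (1 to 3).map { _ =>
      in.readInt() match {
        case SpecialLengths.NULL => None          // -5: null element, no payload follows
        case length =>
          val buf = new Array[Byte](length)       // length-prefixed UTF-8 bytes
          in.readFully(buf)
          Some(new String(buf, "UTF-8"))
      }
    }
    println(decoded)  // Vector(Some(a), None, Some(b))
  }
}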
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index be5ebfa921..b7cfc8bd9c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -22,6 +22,7 @@ import java.io.{File, InputStream, IOException, OutputStream}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
private[spark] object PythonUtils {
/** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +40,8 @@ private[spark] object PythonUtils {
def mergePythonPaths(paths: String*): String = {
paths.filter(_ != "").mkString(File.pathSeparator)
}
+
+ def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+ sc.parallelize(List("a", null, "b"))
+ }
}
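
generateRDDWithNull is a small driver-side helper that exposes a JavaRDD[String] containing a null element, presumably so the Python-side tests can fetch it (e.g. through the JVM gateway) and exercise the new null framing end to end. A minimal local-mode sketch of calling it from Scala follows; the object name and the local[1] master are illustrative, and it assumes the file sits inside the org.apache.spark package tree (here org.apache.spark.api.python) so the private[spark] PythonUtils object is accessible.

package org.apache.spark.api.python

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext

// Hypothetical local driver: materialize the null-containing RDD and print it.
object GenerateRDDWithNullExample {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext(new SparkContext("local[1]", "null-rdd-example"))
    val rdd = PythonUtils.generateRDDWithNull(jsc)
    println(rdd.collect())  // prints [a, null, b]
    jsc.stop()
  }
}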
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 7b866f08a0..c63d834f90 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -23,11 +23,22 @@ import org.scalatest.FunSuite
class PythonRDDSuite extends FunSuite {
- test("Writing large strings to the worker") {
- val input: List[String] = List("a"*100000)
- val buffer = new DataOutputStream(new ByteArrayOutputStream)
- PythonRDD.writeIteratorToStream(input.iterator, buffer)
- }
+ test("Writing large strings to the worker") {
+ val input: List[String] = List("a"*100000)
+ val buffer = new DataOutputStream(new ByteArrayOutputStream)
+ PythonRDD.writeIteratorToStream(input.iterator, buffer)
+ }
+ test("Handle nulls gracefully") {
+ val buffer = new DataOutputStream(new ByteArrayOutputStream)
+ // Should not throw an NPE when writing an Iterator that contains null;
+ // the correctness of the round trip is verified on the Python side
+ PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
+ PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
+ PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
+ PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
+ PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
+ PythonRDD.writeIteratorToStream(
+ Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
+ }
}
-
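
For the pair cases exercised by the last two assertions, write recurses into the tuple, so a null key or value becomes a lone -5 length marker in place of that component. A hypothetical dump of the bytes produced for a single ("a", null) pair, again assuming the snippet is placed in the test's package so the private[spark] PythonRDD object is visible:

package org.apache.spark.api.python

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Hypothetical snippet: print the wire bytes for a pair whose value is null.
object NullFramingDump {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    PythonRDD.writeIteratorToStream(Iterator(("a", null)), out)
    out.flush()
    // Expected: 00 00 00 01 61 (length 1 + "a") then ff ff ff fb (-5 for the null value)
    println(bytes.toByteArray.map(b => f"${b & 0xff}%02x").mkString(" "))
  }
}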