author     Patrick Wendell <pwendell@gmail.com>  2014-01-23 19:08:34 -0800
committer  Patrick Wendell <pwendell@gmail.com>  2014-01-23 19:08:34 -0800
commit     cad3002fead89d3c9a8de4fa989e88f367bc0b05 (patch)
tree       f3b623618e384e14925d57158efce1ca755e67da /core/src
parent     fad6aacfb0a2ac3766417e4a0e3933277ce99d98 (diff)
parent     61569906ccafe4f1d10a61882d564e4bb16665ef (diff)
Merge pull request #501 from JoshRosen/cartesian-rdd-fixes
Fix two bugs in PySpark cartesian(): SPARK-978 and SPARK-1034

This pull request fixes two bugs in PySpark's `cartesian()` method:

- [SPARK-978](https://spark-project.atlassian.net/browse/SPARK-978): PySpark's cartesian() method throws a ClassCastException
- [SPARK-1034](https://spark-project.atlassian.net/browse/SPARK-1034): Py4JException on a PySpark cartesian() result

The JIRAs have more details describing the fixes.
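For context, a minimal PySpark session exercising the cartesian() call these fixes target (a hedged sketch: the local master, app name, and RDD contents are illustrative, not taken from the patch or the JIRAs):

# Illustrative sketch only: exercises the PySpark cartesian() path that
# SPARK-978 and SPARK-1034 report as failing before this patch.
from pyspark import SparkContext

sc = SparkContext("local", "cartesian-example")   # hypothetical app name
rdd1 = sc.parallelize([1, 2, 3])
rdd2 = sc.parallelize(["a", "b"])

# Previously this could fail on the JVM side (ClassCastException /
# Py4JException); with the fix it returns all (left, right) pairs.
pairs = rdd1.cartesian(rdd2).collect()
print(pairs)   # e.g. [(1, 'a'), (1, 'b'), (2, 'a'), ...]

sc.stop()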
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala  |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  | 59
2 files changed, 40 insertions(+), 22 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 0fb7e195b3..f430a33db1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -49,8 +49,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
- override val classTag: ClassTag[(K, V)] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
+ override val classTag: ClassTag[(K, V)] = rdd.elementClassTag
import JavaPairRDD._
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 82527fe663..57bde8d85f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeToStream(elem, dataOut)
- }
+ PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
worker.shutdownOutput()
} catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeToStream(elem: Any, dataOut: DataOutputStream) {
- elem match {
- case bytes: Array[Byte] =>
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- case pair: (Array[Byte], Array[Byte]) =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- case str: String =>
- dataOut.writeUTF(str)
- case other =>
- throw new SparkException("Unexpected element type " + other.getClass)
+ def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+ // The right way to implement this would be to use TypeTags to get the full
+ // type of T. Since I don't want to introduce breaking changes throughout the
+ // entire Spark API, I have to use this hacky approach:
+ if (iter.hasNext) {
+ val first = iter.next()
+ val newIter = Seq(first).iterator ++ iter
+ first match {
+ case arr: Array[Byte] =>
+ newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case string: String =>
+ newIter.asInstanceOf[Iterator[String]].foreach { str =>
+ dataOut.writeUTF(str)
+ }
+ case pair: Tuple2[_, _] =>
+ pair._1 match {
+ case bytePair: Array[Byte] =>
+ newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ }
+ case stringPair: String =>
+ newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+ dataOut.writeUTF(pair._1)
+ dataOut.writeUTF(pair._2)
+ }
+ case other =>
+ throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ }
+ case other =>
+ throw new SparkException("Unexpected element type " + first.getClass)
+ }
}
}
@@ -230,9 +251,7 @@ private[spark] object PythonRDD {
def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
- for (item <- items) {
- writeToStream(item, file)
- }
+ writeIteratorToStream(items, file)
file.close()
}