author    jerryshao <saisai.shao@intel.com>  2014-06-29 23:00:00 -0700
committer Reynold Xin <rxin@apache.org>  2014-06-29 23:00:00 -0700
commit    66135a341d9f8baecc149d13ae5511f14578c395 (patch)
tree      a7abe8a4aeb0e33b8c78b53fde6c62dbd08bdf5a
parent    7b71a0e09622e09285a9884ebb67b5fb1c5caa53 (diff)
[SPARK-2104] Fix task serialization issues when sorting with a Java non-serializable class
Details can be seen in [SPARK-2104](https://issues.apache.org/jira/browse/SPARK-2104). This work is based on Reynold's work and adds some unit tests to validate the issue. @rxin, would you please take a look at this PR? Thanks a lot.

Author: jerryshao <saisai.shao@intel.com>

Closes #1245 from jerryshao/SPARK-2104 and squashes the following commits:

c8ee362 [jerryshao] Make field partitions transient
2b41917 [jerryshao] Minor changes
47d763c [jerryshao] Fix task serializing issue when sort with Java non serializable class
-rw-r--r--  core/src/main/scala/org/apache/spark/Partitioner.scala  52
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala  42
2 files changed, 86 insertions, 8 deletions
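
The failure mode in a nutshell: sortByKey() ships a RangePartitioner to the executors, and before this patch the partitioner's rangeBounds (an Array[K] of sampled keys) was always written with default Java serialization, so a key type that is not java.io.Serializable broke the job even when spark.serializer was set to Kryo. Below is a minimal sketch of a job that hits this, assuming the Spark 1.x API (the object and Key class names are illustrative, not from the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // pair-RDD functions in Spark 1.x

object Spark2104Sketch {
  // Deliberately NOT java.io.Serializable; Comparable supplies the Ordering
  // that sortByKey needs.
  class Key(val value: Int) extends Comparable[Key] {
    override def compareTo(o: Key): Int = value - o.value
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("SPARK-2104 sketch")
      .setMaster("local-cluster[2,1,512]")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val sc = new SparkContext(conf)
    // Before the fix: NotSerializableException while serializing the
    // RangePartitioner's rangeBounds. After it: Kryo handles the bounds.
    val sorted = sc.parallelize(1 to 10, 2).map(x => (new Key(x), x)).sortByKey()
    println(sorted.map(_._2).collect().mkString(", "))
    sc.stop()
  }
}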
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index e7f7548193..ec99648a84 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,11 +17,13 @@
package org.apache.spark
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.CollectionsUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.{CollectionsUtils, Utils}
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner {
* the value of `partitions`.
*/
class RangePartitioner[K : Ordering : ClassTag, V](
- partitions: Int,
+ @transient partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
- private val ascending: Boolean = true)
+ private var ascending: Boolean = true)
extends Partitioner {
- private val ordering = implicitly[Ordering[K]]
+ private var ordering = implicitly[Ordering[K]]
// An array of upper bounds for the first (partitions - 1) partitions
- private val rangeBounds: Array[K] = {
+ private var rangeBounds: Array[K] = {
if (partitions == 1) {
Array()
} else {
@@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
def numPartitions = rangeBounds.length + 1
- private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+ private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
def getPartition(key: Any): Int = {
val k = key.asInstanceOf[K]
@@ -173,4 +175,40 @@ class RangePartitioner[K : Ordering : ClassTag, V](
result = prime * result + ascending.hashCode
result
}
+
+ @throws(classOf[IOException])
+ private def writeObject(out: ObjectOutputStream) {
+ val sfactory = SparkEnv.get.serializer
+ sfactory match {
+ case js: JavaSerializer => out.defaultWriteObject()
+ case _ =>
+ out.writeBoolean(ascending)
+ out.writeObject(ordering)
+ out.writeObject(binarySearch)
+
+ val ser = sfactory.newInstance()
+ Utils.serializeViaNestedStream(out, ser) { stream =>
+ stream.writeObject(scala.reflect.classTag[Array[K]])
+ stream.writeObject(rangeBounds)
+ }
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream) {
+ val sfactory = SparkEnv.get.serializer
+ sfactory match {
+ case js: JavaSerializer => in.defaultReadObject()
+ case _ =>
+ ascending = in.readBoolean()
+ ordering = in.readObject().asInstanceOf[Ordering[K]]
+ binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
+
+ val ser = sfactory.newInstance()
+ Utils.deserializeViaNestedStream(in, ser) { ds =>
+ implicit val classTag = ds.readObject[ClassTag[Array[K]]]()
+ rangeBounds = ds.readObject[Array[K]]()
+ }
+ }
+ }
}
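
writeObject and readObject are the JVM's private serialization hooks: when a class defines them with exactly these signatures, ObjectOutputStream and ObjectInputStream invoke them instead of the default field walk. That is also why the patch changes the vals to vars: readObject must be able to assign the fields on a freshly allocated instance. A minimal self-contained sketch of the same pattern outside Spark (the class and field names are illustrative, not from the patch):

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

// Serializes ordinary fields normally, but hand-encodes one of them.
class BoundsHolder(@transient private var bounds: Array[Int]) extends Serializable {

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()        // non-transient fields, as usual
    out.writeInt(bounds.length)     // then encode the transient field by hand
    bounds.foreach(out.writeInt)
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    // Assignment works only because `bounds` is a var, mirroring the
    // val -> var change in the patch above.
    bounds = Array.fill(in.readInt())(in.readInt())
  }
}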
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index b40fee7e9a..c4f2f7e34f 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// subtracted rdd returns results as Tuple2
results(0) should be ((3, 33))
}
+
+ test("sort with Java non serializable class - Kryo") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ val conf = new SparkConf()
+ .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ .setAppName("test")
+ .setMaster("local-cluster[2,1,512]")
+ sc = new SparkContext(conf)
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (new NonJavaSerializableClass(x), x)
+ }
+ // If the Kryo serializer is not used correctly, the shuffle would fail because the
+ // default Java serializer cannot handle the non serializable class.
+ val c = b.sortByKey().map(x => x._2)
+ assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+ }
+
+ test("sort with Java non serializable class - Java") {
+ // Use a local cluster with 2 processes to make sure there are both local and remote blocks
+ val conf = new SparkConf()
+ .setAppName("test")
+ .setMaster("local-cluster[2,1,512]")
+ sc = new SparkContext(conf)
+ val a = sc.parallelize(1 to 10, 2)
+ val b = a.map { x =>
+ (new NonJavaSerializableClass(x), x)
+ }
+ // default Java serializer cannot handle the non serializable class.
+ val thrown = intercept[SparkException] {
+ b.sortByKey().collect()
+ }
+
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getMessage.contains("NotSerializableException"))
+ }
}
object ShuffleSuite {
@@ -215,5 +251,9 @@ object ShuffleSuite {
x + y
}
- class NonJavaSerializableClass(val value: Int)
+ class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] {
+ override def compareTo(o: NonJavaSerializableClass): Int = {
+ value - o.value
+ }
+ }
}
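
A side note on the test helper: extending Comparable is what lets sortByKey resolve an Ordering for NonJavaSerializableClass without the class being java.io.Serializable, because Scala derives Ordering[A] for any A <: Comparable[A] via scala.math.Ordering.ordered. A tiny standalone illustration (the names here are ours, not Spark's):

object ComparableOrderingDemo extends App {
  class C(val v: Int) extends Comparable[C] {
    override def compareTo(o: C): Int = v - o.v
  }
  // Resolves through scala.math.Ordering.ordered; no Serializable required.
  val ord = implicitly[Ordering[C]]
  println(ord.compare(new C(1), new C(2)))  // prints -1
}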