diff options
author | jerryshao <saisai.shao@intel.com> | 2014-06-29 23:00:00 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-06-29 23:00:00 -0700 |
commit | 66135a341d9f8baecc149d13ae5511f14578c395 (patch) | |
tree | a7abe8a4aeb0e33b8c78b53fde6c62dbd08bdf5a | |
parent | 7b71a0e09622e09285a9884ebb67b5fb1c5caa53 (diff) | |
download | spark-66135a341d9f8baecc149d13ae5511f14578c395.tar.gz spark-66135a341d9f8baecc149d13ae5511f14578c395.tar.bz2 spark-66135a341d9f8baecc149d13ae5511f14578c395.zip |
[SPARK-2104] Fix task serialization issues when sorting with a Java non-serializable class
Details can be seen in [SPARK-2104](https://issues.apache.org/jira/browse/SPARK-2104). This work is based on Reynold's work and adds some unit tests to validate the issue.
@rxin, would you please take a look at this PR? Thanks a lot.
Author: jerryshao <saisai.shao@intel.com>
Closes #1245 from jerryshao/SPARK-2104 and squashes the following commits:
c8ee362 [jerryshao] Make field partitions transient
2b41917 [jerryshao] Minor changes
47d763c [jerryshao] Fix task serializing issue when sort with Java non serializable class
-rw-r--r-- | core/src/main/scala/org/apache/spark/Partitioner.scala | 52 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 42 |
2 files changed, 86 insertions, 8 deletions
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index e7f7548193..ec99648a84 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -17,11 +17,13 @@ package org.apache.spark +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.reflect.ClassTag import org.apache.spark.rdd.RDD -import org.apache.spark.util.CollectionsUtils -import org.apache.spark.util.Utils +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.{CollectionsUtils, Utils} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. @@ -96,15 +98,15 @@ class HashPartitioner(partitions: Int) extends Partitioner { * the value of `partitions`. */ class RangePartitioner[K : Ordering : ClassTag, V]( - partitions: Int, + @transient partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], - private val ascending: Boolean = true) + private var ascending: Boolean = true) extends Partitioner { - private val ordering = implicitly[Ordering[K]] + private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - private val rangeBounds: Array[K] = { + private var rangeBounds: Array[K] = { if (partitions == 1) { Array() } else { @@ -127,7 +129,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( def numPartitions = rangeBounds.length + 1 - private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] + private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] def getPartition(key: Any): Int = { val k = key.asInstanceOf[K] @@ -173,4 +175,40 @@ class RangePartitioner[K : Ordering : ClassTag, V]( result = prime * result + ascending.hashCode result } + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream) { + val sfactory = 
SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => out.defaultWriteObject() + case _ => + out.writeBoolean(ascending) + out.writeObject(ordering) + out.writeObject(binarySearch) + + val ser = sfactory.newInstance() + Utils.serializeViaNestedStream(out, ser) { stream => + stream.writeObject(scala.reflect.classTag[Array[K]]) + stream.writeObject(rangeBounds) + } + } + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream) { + val sfactory = SparkEnv.get.serializer + sfactory match { + case js: JavaSerializer => in.defaultReadObject() + case _ => + ascending = in.readBoolean() + ordering = in.readObject().asInstanceOf[Ordering[K]] + binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int] + + val ser = sfactory.newInstance() + Utils.deserializeViaNestedStream(in, ser) { ds => + implicit val classTag = ds.readObject[ClassTag[Array[K]]]() + rangeBounds = ds.readObject[Array[K]]() + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b40fee7e9a..c4f2f7e34f 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -206,6 +206,42 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // substracted rdd return results as Tuple2 results(0) should be ((3, 33)) } + + test("sort with Java non serializable class - Kryo") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + val conf = new SparkConf() + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .setAppName("test") + .setMaster("local-cluster[2,1,512]") + sc = new SparkContext(conf) + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (new NonJavaSerializableClass(x), x) + } + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle 
the non serializable class. + val c = b.sortByKey().map(x => x._2) + assert(c.collect() === Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + } + + test("sort with Java non serializable class - Java") { + // Use a local cluster with 2 processes to make sure there are both local and remote blocks + val conf = new SparkConf() + .setAppName("test") + .setMaster("local-cluster[2,1,512]") + sc = new SparkContext(conf) + val a = sc.parallelize(1 to 10, 2) + val b = a.map { x => + (new NonJavaSerializableClass(x), x) + } + // default Java serializer cannot handle the non serializable class. + val thrown = intercept[SparkException] { + b.sortByKey().collect() + } + + assert(thrown.getClass === classOf[SparkException]) + assert(thrown.getMessage.contains("NotSerializableException")) + } } object ShuffleSuite { @@ -215,5 +251,9 @@ object ShuffleSuite { x + y } - class NonJavaSerializableClass(val value: Int) + class NonJavaSerializableClass(val value: Int) extends Comparable[NonJavaSerializableClass] { + override def compareTo(o: NonJavaSerializableClass): Int = { + value - o.value + } + } } |