-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                          |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala                  |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala                  |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala                      | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala  | 71
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala             | 23
6 files changed, 128 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index eac1f2326a..fb3f7bd54b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ // Set the classloader for the serializer
+ env.serializer.setDefaultClassLoader(urlClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
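
The one-line hook above is the heart of the fix: the executor owns the URLClassLoader that can see the application's jars, and it must hand that loader to the shared serializer before the first task result is deserialized. A minimal sketch of this wiring, using illustrative stand-in types rather than Spark's real SparkEnv/Executor classes:

import java.net.{URL, URLClassLoader}

// Illustrative stand-ins for the serializer trait and executor in this diff.
trait SerializerLike {
  @volatile protected var defaultClassLoader: Option[ClassLoader] = None
  def setDefaultClassLoader(loader: ClassLoader): SerializerLike = {
    defaultClassLoader = Some(loader)
    this
  }
}

class ExecutorLike(appJars: Array[URL], serializer: SerializerLike) {
  // Build a loader that can see the application's jars, then register it with
  // the serializer before any task bytes are deserialized.
  private val urlClassLoader = new URLClassLoader(appJars, getClass.getClassLoader)
  serializer.setDefaultClassLoader(urlClassLoader)
}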
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 34bc312409..af33a2f2ca 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+ extends SerializerInstance {
+
def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
- def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+ override def newInstance(): SerializerInstance = {
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+ new JavaSerializerInstance(counterReset, classLoader)
+ }
override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
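
Passing the resolved loader into JavaSerializerInstance only helps if the deserializing side actually consults it. For Java serialization the conventional hook is overriding ObjectInputStream.resolveClass; the sketch below shows the shape of that hook (the class name and fallback policy here are illustrative, not lifted from this diff):

import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

// A deserialization stream that resolves classes against a supplied loader
// instead of relying on ObjectInputStream's default lookup.
class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {

  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    try {
      // `false` defers static initialization until the class is actually used.
      Class.forName(desc.getName, false, loader)
    } catch {
      case _: ClassNotFoundException => super.resolveClass(desc)
    }
}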
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 85944eabcf..99682220b4 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
- val classLoader = Thread.currentThread.getContextClassLoader
+
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
+
+ // Use the default classloader when calling the user registrator.
+ Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}
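
The save/set/restore dance around registerClasses is a general pattern worth naming. A hypothetical helper (not part of this diff) that captures it, guaranteeing the old loader is restored even when the registrator throws:

object ClassLoaderUtil {
  /**
   * Runs `body` with `loader` installed as the thread's context class loader,
   * restoring the previous loader afterwards, even if `body` throws.
   */
  def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
    val thread = Thread.currentThread
    val previous = thread.getContextClassLoader
    thread.setContextClassLoader(loader)
    try body finally thread.setContextClassLoader(previous)
  }
}

// With such a helper, the registrator call above would read:
//   ClassLoaderUtil.withContextClassLoader(classLoader) { reg.registerClasses(kryo) }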
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2f5cea469..e674438c81 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {
+
+ /**
+ * Default ClassLoader to use when deserializing. Implementations of [[Serializer]] should
+ * use this class loader, when it is set, to resolve the classes being deserialized.
+ */
+ @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+ /**
+ * Sets a class loader for the serializer to use in deserialization.
+ *
+ * @return this Serializer object
+ */
+ def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+ defaultClassLoader = Some(classLoader)
+ this
+ }
+
def newInstance(): SerializerInstance
}
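
Returning `this` lets call sites attach the loader at construction time in a single expression. A sketch of the intended call pattern (the URLClassLoader built here is a placeholder for a real application loader):

import java.net.{URL, URLClassLoader}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}

object SerializerSetup {
  // Chain the fluent setter straight into newInstance(): the returned instance
  // will resolve classes through `appJars` rather than the thread's default loader.
  def instanceFor(appJars: Array[URL]): SerializerInstance = {
    val loader = new URLClassLoader(appJars, Thread.currentThread.getContextClassLoader)
    new JavaSerializer(new SparkConf).setDefaultClassLoader(loader).newInstance()
  }
}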
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
new file mode 100644
index 0000000000..11e8c9c4cb
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.util.Utils
+
+import com.esotericsoftware.kryo.Kryo
+import org.scalatest.FunSuite
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
+import org.apache.spark.SparkContext._
+import org.apache.spark.serializer.KryoDistributedTest._
+
+class KryoSerializerDistributedSuite extends FunSuite {
+
+ test("kryo objects are serialised consistently in different processes") {
+ val conf = new SparkConf(false)
+ conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
+ conf.set("spark.task.maxFailures", "1")
+
+ val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
+ conf.setJars(List(jar.getPath))
+
+ val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ val original = Thread.currentThread.getContextClassLoader
+ val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
+ SparkEnv.get.serializer.setDefaultClassLoader(loader)
+
+ val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()
+
+ // Mix up the keys deterministically (via a cubic) so that the join below requires
+ // a shuffle, with each partition sending data to multiple other partitions.
+ val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o) }
+
+ // Join the two RDDs, and force evaluation
+ assert(shuffledRDD.join(cachedRDD).collect().size == 1)
+
+ LocalSparkContext.stop(sc)
+ }
+}
+
+object KryoDistributedTest {
+ class MyCustomClass
+
+ class AppJarRegistrator extends KryoRegistrator {
+ override def registerClasses(k: Kryo) {
+ val classLoader = Thread.currentThread.getContextClassLoader
+ k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
+ }
+ }
+
+ object AppJarRegistrator {
+ val customClassName = "KryoSerializerDistributedSuiteCustomClass"
+ }
+}
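
The suite works because AppJarRegistrator.customClassName exists only inside the generated jar: resolution succeeds through a loader that includes the jar and fails through one that does not. A standalone sketch of that contrast (the jar URL and class name are placeholders):

import java.net.{URL, URLClassLoader}

object AppJarLookup {
  /** Resolve `className` through a loader that can see `jar`. */
  def loadFromJar(jar: URL, className: String): Class[_] = {
    val loader = new URLClassLoader(Array(jar), Thread.currentThread.getContextClassLoader)
    // A bare Class.forName(className) would throw ClassNotFoundException on an
    // executor whose context class loader has never seen the application jar.
    Class.forName(className, true, loader)
  }
}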
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 3bf9efebb3..a579fd50bd 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
import org.apache.spark.serializer.KryoTest._
class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
}
+
+ test("default class loader can be set by a different thread") {
+ val ser = new KryoSerializer(new SparkConf)
+
+ // First serialize the object
+ val serInstance = ser.newInstance()
+ val bytes = serInstance.serialize(new ClassLoaderTestingObject)
+
+ // Deserialize the object to make sure normal deserialization works
+ serInstance.deserialize[ClassLoaderTestingObject](bytes)
+
+ // Set a special, broken ClassLoader and make sure we get an exception on deserialization
+ ser.setDefaultClassLoader(new ClassLoader() {
+ override def loadClass(name: String) = throw new UnsupportedOperationException
+ })
+ intercept[UnsupportedOperationException] {
+ ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
+ }
+ }
}
+class ClassLoaderTestingObject
+
class KryoSerializerResizableOutputSuite extends FunSuite {
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext