author     Reynold Xin <rxin@apache.org>   2014-08-15 17:04:15 -0700
committer  Reynold Xin <rxin@apache.org>   2014-08-15 17:04:15 -0700
commit     cc3648774e9a744850107bb187f2828d447e0a48 (patch)
tree       66c0d5194762cdc156d2769e2bcd912302dac62c /core
parent     c7032290a3f0f5545aa4f0a9a144c62571344dc8 (diff)
[SPARK-3046] use executor's class loader as the default serializer classloader
The serializer is not always used in an executor thread (e.g. connection manager, broadcast), in which case the classloader might not have the user jar set, leading to corruption in deserialization.

https://issues.apache.org/jira/browse/SPARK-3046
https://issues.apache.org/jira/browse/SPARK-2878

Author: Reynold Xin <rxin@apache.org>

Closes #1972 from rxin/kryoBug and squashes the following commits:

c1c7bf0 [Reynold Xin] Made change to JavaSerializer.
7204c33 [Reynold Xin] Added imports back.
d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader.
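The failure mode is easiest to see in isolation: class resolution during deserialization walks a ClassLoader, and a thread that Spark never set up with the user-jar loader (such as the connection manager's) resolves against whatever context classloader it happened to inherit. A hypothetical repro sketch, not code from this patch (`com.example.UserClass` stands in for any class that exists only in the user jar):

```scala
// A loader that can only see bootstrap classes, simulating a helper thread
// that was never handed the user-jar classloader.
val emptyLoader = new ClassLoader(null: ClassLoader) {}

val t = new Thread(new Runnable {
  override def run(): Unit = {
    Thread.currentThread.setContextClassLoader(emptyLoader)
    // Any deserializer consulting the context classloader here fails to
    // resolve user classes (ClassNotFoundException); with Kryo this can
    // also surface as corrupted deserialization results.
    Class.forName("com.example.UserClass", true,
      Thread.currentThread.getContextClassLoader)
  }
})
t.start()
```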
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                          |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala                  |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala                  |  9
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala                      | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala  | 71
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala             | 23
6 files changed, 128 insertions(+), 4 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index eac1f2326a..fb3f7bd54b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ // Set the classloader for serializer
+ env.serializer.setDefaultClassLoader(urlClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
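Because the field behind `setDefaultClassLoader` is volatile (see the Serializer.scala change below), setting it once here, on the executor's startup path, covers every thread that later asks for a serializer instance. A hedged sketch of the resulting behavior, using `JavaSerializer` for brevity (the loader choice and the thread are illustrative, not this file's code):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

val ser = new JavaSerializer(new SparkConf)
ser.setDefaultClassLoader(getClass.getClassLoader) // done once at startup

// A helper thread (think connection manager or broadcast) may carry an
// unrelated context classloader, yet still deserializes user classes,
// because newInstance() prefers the default loader when one is set.
new Thread(new Runnable {
  override def run(): Unit = {
    val instance = ser.newInstance()
    // instance.deserialize[SomeUserClass](bytes) would now resolve correctly
  }
}).start()
```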
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 34bc312409..af33a2f2ca 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+ extends SerializerInstance {
+
def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
- def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+ override def newInstance(): SerializerInstance = {
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+ new JavaSerializerInstance(counterReset, classLoader)
+ }
override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
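The `getOrElse` keeps the pre-patch behavior as the fallback: when no default has been set (e.g. on the driver), the calling thread's context classloader is used exactly as before. Downstream, the captured loader is what class resolution runs against, via the standard `ObjectInputStream` override; a simplified, hedged sketch of that plumbing, not this file's full implementation:

```scala
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectStreamClass}

// Deserialize `bytes`, resolving every class against `loader` instead of
// the calling thread's context classloader.
def deserializeWith[T](bytes: Array[Byte], loader: ClassLoader): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
    override def resolveClass(desc: ObjectStreamClass): Class[_] =
      Class.forName(desc.getName, false, loader)
  }
  try in.readObject().asInstanceOf[T] finally in.close()
}
```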
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 85944eabcf..99682220b4 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
- val classLoader = Thread.currentThread.getContextClassLoader
+
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
+
+ // Use the default classloader when calling the user registrator.
+ Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}
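The `oldClassLoader` bookkeeping with `try`/`finally` above is the classic swap-and-restore idiom for the context classloader: the user's `KryoRegistrator` runs with the right loader, and the thread is left exactly as it was found even if registration throws. Extracted into a reusable helper (a sketch, not code from this patch), the idiom is just:

```scala
// Run `body` with `loader` as the context classloader, always restoring
// the previous loader afterwards, even when `body` throws.
def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
  val old = Thread.currentThread.getContextClassLoader
  Thread.currentThread.setContextClassLoader(loader)
  try body finally Thread.currentThread.setContextClassLoader(old)
}
```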
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2f5cea469..e674438c81 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {
+
+ /**
+ * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
+ * make sure it is using this when set.
+ */
+ @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+ /**
+ * Sets a class loader for the serializer to use in deserialization.
+ *
+ * @return this Serializer object
+ */
+ def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+ defaultClassLoader = Some(classLoader)
+ this
+ }
+
def newInstance(): SerializerInstance
}
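Two design choices are visible here: the field is `@volatile`, so a loader set from the executor's startup thread is safely visible to whichever thread later calls `newInstance()`, and `setDefaultClassLoader` returns the `Serializer` itself, so configuration chains fluently. A hedged usage sketch (`conf` and `userJarLoader` are assumed to exist in the caller's scope):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, Serializer}

def configuredSerializer(conf: SparkConf, userJarLoader: ClassLoader): Serializer =
  new KryoSerializer(conf).setDefaultClassLoader(userJarLoader)
```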
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
new file mode 100644
index 0000000000..11e8c9c4cb
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.util.Utils
+
+import com.esotericsoftware.kryo.Kryo
+import org.scalatest.FunSuite
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
+import org.apache.spark.SparkContext._
+import org.apache.spark.serializer.KryoDistributedTest._
+
+class KryoSerializerDistributedSuite extends FunSuite {
+
+ test("kryo objects are serialised consistently in different processes") {
+ val conf = new SparkConf(false)
+ conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
+ conf.set("spark.task.maxFailures", "1")
+
+ val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
+ conf.setJars(List(jar.getPath))
+
+ val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ val original = Thread.currentThread.getContextClassLoader
+ val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
+ SparkEnv.get.serializer.setDefaultClassLoader(loader)
+
+ val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()
+
+ // Randomly mix the keys so that the join below will require a shuffle with each partition
+ // sending data to multiple other partitions.
+ val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}
+
+ // Join the two RDDs, and force evaluation
+ assert(shuffledRDD.join(cachedRDD).collect().size == 1)
+
+ LocalSparkContext.stop(sc)
+ }
+}
+
+object KryoDistributedTest {
+ class MyCustomClass
+
+ class AppJarRegistrator extends KryoRegistrator {
+ override def registerClasses(k: Kryo) {
+ val classLoader = Thread.currentThread.getContextClassLoader
+ k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
+ }
+ }
+
+ object AppJarRegistrator {
+ val customClassName = "KryoSerializerDistributedSuiteCustomClass"
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 3bf9efebb3..a579fd50bd 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
import org.apache.spark.serializer.KryoTest._
class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
}
+
+ test("default class loader can be set by a different thread") {
+ val ser = new KryoSerializer(new SparkConf)
+
+ // First serialize the object
+ val serInstance = ser.newInstance()
+ val bytes = serInstance.serialize(new ClassLoaderTestingObject)
+
+ // Deserialize the object to make sure normal deserialization works
+ serInstance.deserialize[ClassLoaderTestingObject](bytes)
+
+ // Set a special, broken ClassLoader and make sure we get an exception on deserialization
+ ser.setDefaultClassLoader(new ClassLoader() {
+ override def loadClass(name: String) = throw new UnsupportedOperationException
+ })
+ intercept[UnsupportedOperationException] {
+ ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
+ }
+ }
}
+class ClassLoaderTestingObject
+
class KryoSerializerResizableOutputSuite extends FunSuite {
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext