about | summary | refs | log | tree | commit | diff
path: root/core/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/Serializer.scala17
4 files changed, 35 insertions, 3 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index eac1f2326a..fb3f7bd54b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -99,6 +99,9 @@ private[spark] class Executor(
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ // Set the classloader for serializer
+ env.serializer.setDefaultClassLoader(urlClassLoader)
+
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 34bc312409..af33a2f2ca 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -63,7 +63,9 @@ extends DeserializationStream {
def close() { objIn.close() }
}
-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+ extends SerializerInstance {
+
def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@@ -109,7 +111,10 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
- def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+ override def newInstance(): SerializerInstance = {
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+ new JavaSerializerInstance(counterReset, classLoader)
+ }
override def writeExternal(out: ObjectOutput) {
out.writeInt(counterReset)
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 85944eabcf..99682220b4 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
val instantiator = new EmptyScalaKryoInstantiator
val kryo = instantiator.newKryo()
kryo.setRegistrationRequired(registrationRequired)
- val classLoader = Thread.currentThread.getContextClassLoader
+
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
// Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
// Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
try {
val reg = Class.forName(regCls, true, classLoader).newInstance()
.asInstanceOf[KryoRegistrator]
+
+ // Use the default classloader when calling the user registrator.
+ Thread.currentThread.setContextClassLoader(classLoader)
reg.registerClasses(kryo)
} catch {
case e: Exception =>
throw new SparkException(s"Failed to invoke $regCls", e)
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2f5cea469..e674438c81 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
*/
@DeveloperApi
trait Serializer {
+
+ /**
+ * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
+ * make sure it is using this when set.
+ */
+ @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+ /**
+ * Sets a class loader for the serializer to use in deserialization.
+ *
+ * @return this Serializer object
+ */
+ def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+ defaultClassLoader = Some(classLoader)
+ this
+ }
+
def newInstance(): SerializerInstance
}