author     Patrick Wendell <pwendell@gmail.com>    2014-04-13 08:58:37 -0700
committer  Patrick Wendell <pwendell@gmail.com>    2014-04-13 08:58:37 -0700
commit     4bc07eebbf5e2ea0c0b6f1642049515025d88d07 (patch)
tree       fc314a1c1d68055b04cdc37553669ea5f12c628b /core
parent     ca11919e6e97a62eb3e3ce882ffa29eae36f50f7 (diff)
download   spark-4bc07eebbf5e2ea0c0b6f1642049515025d88d07.tar.gz
           spark-4bc07eebbf5e2ea0c0b6f1642049515025d88d07.tar.bz2
           spark-4bc07eebbf5e2ea0c0b6f1642049515025d88d07.zip
SPARK-1480: Clean up use of classloaders
The Spark codebase is a bit fast-and-loose when accessing classloaders and this has caused a few bugs to surface in master. This patch defines some utility methods for accessing classloaders. This makes the intention when accessing a classloader much more explicit in the code and fixes a few cases where the wrong one was chosen.

case (a) -> We want the classloader that loaded Spark
case (b) -> We want the context class loader, or if not present, we want (a)

This patch provides a better fix for SPARK-1403 (https://issues.apache.org/jira/browse/SPARK-1403) than the current workaround, which it reverts. It also fixes a previously unreported bug: the `./spark-submit` script did not work for running with a `local` master, because the executor classloader did not properly delegate to the context class loader (if it is defined), and in local mode the context class loader is set by the `./spark-submit` script. A unit test is added for that case.

Author: Patrick Wendell <pwendell@gmail.com>

Closes #398 from pwendell/class-loaders and squashes the following commits:

b4a1a58 [Patrick Wendell] Minor clean up
14f1272 [Patrick Wendell] SPARK-1480: Clean up use of classloaders
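To make the two cases concrete, here is a minimal sketch of the policy the patch adopts; the actual helper definitions this commit adds are in the Utils.scala hunk further down:

    // Case (a): the loader that loaded the Spark classes themselves. Use it to look up
    // resources bundled with Spark (log4j defaults, metrics config, scheduler XML, ...).
    def getSparkClassLoader: ClassLoader = getClass.getClassLoader

    // Case (b): the thread's context class loader when one is set (e.g. by ./spark-submit
    // with a local master), otherwise fall back to case (a).
    def getContextOrSparkClassLoader: ClassLoader =
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)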
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/Logging.scala                              |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                    |  6
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala        | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala                |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala                 |  1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala         |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala           |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala            |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/JettyUtils.scala                        |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala                           | 15
-rw-r--r--  core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala | 32
11 files changed, 67 insertions, 29 deletions
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index 9d429dceeb..50d8e93e1f 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -22,6 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.impl.StaticLoggerBinder
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -115,8 +116,7 @@ trait Logging {
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- val classLoader = this.getClass.getClassLoader
- Option(classLoader.getResource(defaultLogProps)) match {
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
case Some(url) =>
PropertyConfigurator.configure(url)
log.info(s"Using Spark's default log4j profile: $defaultLogProps")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c12bd922d4..f89b2bffd1 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -292,7 +292,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
- val loader = this.getClass.getClassLoader
+ val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -301,8 +301,8 @@ private[spark] class Executor(
}.toArray
val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
userClassPathFirst match {
- case true => new ChildExecutorURLClassLoader(urls, loader)
- case false => new ExecutorURLClassLoader(urls, loader)
+ case true => new ChildExecutorURLClassLoader(urls, currentLoader)
+ case false => new ExecutorURLClassLoader(urls, currentLoader)
}
}
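The two loader types selected in this hunk differ only in delegation order: ExecutorURLClassLoader is an ordinary parent-first URLClassLoader, while the child-first variant tries the user-supplied jars before asking its parent, so user classes can shadow ones already on the Spark classpath. A rough, hypothetical sketch of child-first delegation (not Spark's actual ChildExecutorURLClassLoader, which handles more cases):

    import java.net.{URL, URLClassLoader}

    class ChildFirstLoader(urls: Array[URL], parent: ClassLoader)
        extends URLClassLoader(urls, null) {  // null parent: only bootstrap + our URLs
      override def loadClass(name: String): Class[_] =
        try {
          super.loadClass(name)               // look in the user-supplied URLs first
        } catch {
          case _: ClassNotFoundException => parent.loadClass(name)  // then the normal chain
        }
    }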
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index df36a06485..6fc702fdb1 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -50,21 +50,13 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- val cl = Thread.currentThread.getContextClassLoader
- try {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(getClass.getClassLoader)
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
- this.driver = driver
- val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
- executor = new Executor(
- executorInfo.getExecutorId.getValue,
- slaveInfo.getHostname,
- properties)
- } finally {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(cl)
- }
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+ this.driver = driver
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ executor = new Executor(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index 3e3e18c353..1b7a5d1f19 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.matching.Regex
import org.apache.spark.Logging
+import org.apache.spark.util.Utils
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
@@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
try {
is = configFile match {
case Some(f) => new FileInputStream(f)
- case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
+ case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
}
if (is != null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 083fb895d8..0b381308b6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -54,7 +54,6 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
- val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index e4eced383c..6c5827f75e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
import scala.xml.XML
import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.Utils
/**
* An interface to build Schedulable tree
@@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
schedulerAllocFile.map { f =>
new FileInputStream(f)
}.getOrElse {
- getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index cb4ad4ae93..c9ad2b151d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
- serializedData, getClass.getClassLoader)
+ serializedData, Utils.getSparkClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
- val loader = Thread.currentThread.getContextClassLoader
+ val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Throwable => {}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 5e5883554f..e9163deaf2 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.Utils
private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
@@ -86,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
+ new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 62a4e3d0f6..3ae147a36c 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -33,6 +33,7 @@ import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}
import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.util.Utils
/**
* Utilities for launching a web server using Jetty's HTTP Server class
@@ -124,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
val staticHandler = new DefaultServlet
val holder = new ServletHolder(staticHandler)
- Option(getClass.getClassLoader.getResource(resourceBase)) match {
+ Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
case Some(res) =>
holder.setInitParameter("resourceBase", res.toString)
case None =>
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 166f48ce73..a3af4e7b91 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -117,6 +117,21 @@ private[spark] object Utils extends Logging {
}
/**
+ * Get the ClassLoader which loaded Spark.
+ */
+ def getSparkClassLoader = getClass.getClassLoader
+
+ /**
+ * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
+ * loaded Spark.
+ *
+ * This should be used whenever passing a ClassLoader to Class.forName or finding the currently
+ * active loader when setting up ClassLoader delegation chains.
+ */
+ def getContextOrSparkClassLoader =
+ Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)
+
+ /**
* Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
*/
def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
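A short usage sketch for the two new helpers, for code inside the org.apache.spark packages (Utils is private[spark]); the resource path is the one used by Logging above, while the user class name is purely illustrative:

    import org.apache.spark.util.Utils

    // Resources shipped inside the Spark jar must be resolved against the loader that
    // loaded Spark itself; the caller's context loader may not see them.
    val url = Utils.getSparkClassLoader.getResource("org/apache/spark/log4j-defaults.properties")

    // Classes that may come from the user's jars should go through the context loader
    // when one is set (e.g. by ./spark-submit in local mode), falling back to the Spark loader.
    val cls = Class.forName("com.example.UserClass", true, Utils.getContextOrSparkClassLoader)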
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
index c40cfc0696..e2050e95a1 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala
@@ -17,12 +17,12 @@
package org.apache.spark.executor
-import java.io.File
import java.net.URLClassLoader
import org.scalatest.FunSuite
-import org.apache.spark.TestUtils
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
+import org.apache.spark.util.Utils
class ExecutorURLClassLoaderSuite extends FunSuite {
@@ -63,5 +63,33 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
}
}
+ test("driver sets context class loader in local mode") {
+ // Test the case where the driver program sets a context classloader and then runs a job
+ // in local mode. This is what happens when ./spark-submit is called with "local" as the
+ // master.
+ val original = Thread.currentThread().getContextClassLoader
+ val className = "ClassForDriverTest"
+ val jar = TestUtils.createJarWithClasses(Seq(className))
+ val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
+ Thread.currentThread().setContextClassLoader(contextLoader)
+
+ val sc = new SparkContext("local", "driverLoaderTest")
+
+ try {
+ sc.makeRDD(1 to 5, 2).mapPartitions { x =>
+ val loader = Thread.currentThread().getContextClassLoader
+ Class.forName(className, true, loader).newInstance()
+ Seq().iterator
+ }.count()
+ }
+ catch {
+ case e: SparkException if e.getMessage.contains("ClassNotFoundException") =>
+ fail("Local executor could not find class", e)
+ case t: Throwable => fail("Unexpected exception ", t)
+ }
+
+ sc.stop()
+ Thread.currentThread().setContextClassLoader(original)
+ }
}
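Taken together with the Executor change above, this test exercises the delegation chain used in local mode: ./spark-submit sets a context class loader that contains the user's jar, and the executor's loader is now parented on it via getContextOrSparkClassLoader, so task code can resolve user classes. A rough sketch of that chain (Spark-internal names as in the diff; jarUrls stands in for the executor's fetched jar URLs):

    // parent: the driver's context loader when set (local mode), else the Spark loader
    val parent = Utils.getContextOrSparkClassLoader
    // the executor layers its fetched jars on top of that parent
    val executorLoader = new ExecutorURLClassLoader(jarUrls, parent)
    // a class added by the driver (e.g. ClassForDriverTest above) now resolves in tasks
    Class.forName("ClassForDriverTest", true, executorLoader)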