From 0121a26bd150e5f76d950e08cf4d536fad635a40 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Fri, 28 Sep 2012 16:14:05 -0700 Subject: Changed the way tasks' dependency files are sent to workers so that custom serializers or Kryo registrators can be loaded. --- .../examples/WikipediaPageRankStandalone.scala | 5 +- core/src/main/scala/spark/JavaSerializer.scala | 2 +- core/src/main/scala/spark/KryoSerializer.scala | 5 +- core/src/main/scala/spark/Serializer.scala | 2 +- core/src/main/scala/spark/SparkEnv.scala | 48 +++------ core/src/main/scala/spark/executor/Executor.scala | 56 ++++++----- .../spark/executor/ExecutorURLClassLoader.scala | 15 +++ .../scala/spark/scheduler/ShuffleMapTask.scala | 41 -------- core/src/main/scala/spark/scheduler/Task.scala | 107 +++++++++++++++------ .../spark/scheduler/cluster/ClusterScheduler.scala | 6 +- .../spark/scheduler/cluster/TaskSetManager.scala | 3 +- .../spark/scheduler/local/LocalScheduler.scala | 68 +++++++++---- .../scala/spark/storage/BlockManagerMaster.scala | 2 + .../scala/spark/util/ByteBufferInputStream.scala | 2 + 14 files changed, 206 insertions(+), 156 deletions(-) create mode 100644 core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index ed8ace3a57..8ced0f9c73 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -142,7 +142,7 @@ class WPRSerializerInstance extends SerializerInstance { class WPRSerializationStream(os: OutputStream) extends SerializationStream { val dos = new DataOutputStream(os) - def writeObject[T](t: T): Unit = t match { + def writeObject[T](t: T): SerializationStream = t match { case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { case links: Array[String] => { dos.writeInt(0) // links @@ -151,17 +151,20 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream { for (link <- links) { dos.writeUTF(link) } + this } case rank: Double => { dos.writeInt(1) // rank dos.writeUTF(id) dos.writeDouble(rank) + this } } case (id: String, rank: Double) => { dos.writeInt(2) // rank without wrapper dos.writeUTF(id) dos.writeDouble(rank) + this } } diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index d11ba5167d..1511c2620e 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -7,7 +7,7 @@ import spark.util.ByteBufferInputStream class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T) { objOut.writeObject(t) } + def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } def flush() { objOut.flush() } def close() { objOut.close() } } diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 8aa27a747b..376fcff4c8 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -72,12 +72,13 @@ class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputS extends SerializationStream { val channel = Channels.newChannel(out) - def writeObject[T](t: T) { + def writeObject[T](t: T): SerializationStream = { kryo.writeClassAndObject(threadBuffer, t) 
ZigZag.writeInt(threadBuffer.position(), out) threadBuffer.flip() channel.write(threadBuffer) threadBuffer.clear() + this } def flush() { out.flush() } @@ -161,6 +162,8 @@ trait KryoRegistrator { } class KryoSerializer extends Serializer with Logging { + // Make this lazy so that it only gets called once we receive our first task on each executor, + // so we can pull out any custom Kryo registrator from the user's JARs. lazy val kryo = createKryo() val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala index 5f26bd2a7b..9ec07cc173 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/Serializer.scala @@ -51,7 +51,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ trait SerializationStream { - def writeObject[T](t: T): Unit + def writeObject[T](t: T): SerializationStream def flush(): Unit def close(): Unit diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 6ffae8e85f..2c9f46b1a0 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -74,11 +74,18 @@ object SparkEnv { System.setProperty("spark.master.port", boundPort.toString) } - val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") - val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + val classLoader = Thread.currentThread.getContextClassLoader + + // Create an instance of the class named by the given Java system property, or by + // defaultClassName if the property is not set, and return it as a T + def instantiateClass[T](propertyName: String, defaultClassName: String): T = { + val name = System.getProperty(propertyName, defaultClassName) + Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] + } + + val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) - val blockManager = new BlockManager(blockManagerMaster, serializer) val connectionManager = blockManager.connectionManager @@ -87,45 +94,22 @@ object SparkEnv { val broadcastManager = new BroadcastManager(isMaster) - val closureSerializerClass = - System.getProperty("spark.closure.serializer", "spark.JavaSerializer") - val closureSerializer = - Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] - val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") - val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] + val closureSerializer = instantiateClass[Serializer]( + "spark.closure.serializer", "spark.JavaSerializer") + + val cache = instantiateClass[Cache]("spark.cache.class", "spark.BoundedMemoryCache") val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) blockManager.cacheTracker = cacheTracker val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) - val shuffleFetcherClass = - System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - val shuffleFetcher = - Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher] + val shuffleFetcher = instantiateClass[ShuffleFetcher]( + "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") val httpFileServer = new HttpFileServer() httpFileServer.initialize() System.setProperty("spark.fileserver.uri", 
httpFileServer.serverUri) - - /* - if (System.getProperty("spark.stream.distributed", "false") == "true") { - val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] - if (isLocal || !isMaster) { - (new Thread() { - override def run() { - println("Wait started") - Thread.sleep(60000) - println("Wait ended") - val receiverClass = Class.forName("spark.stream.TestStreamReceiver4") - val constructor = receiverClass.getConstructor(blockManagerClass) - val receiver = constructor.newInstance(blockManager) - receiver.asInstanceOf[Thread].start() - } - }).start() - } - } - */ new SparkEnv( actorSystem, diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 9999b6ba80..820428c727 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -20,10 +20,11 @@ class Executor extends Logging { var urlClassLoader : ExecutorURLClassLoader = null var threadPool: ExecutorService = null var env: SparkEnv = null - - val fileSet: HashMap[String, Long] = new HashMap[String, Long]() - val jarSet: HashMap[String, Long] = new HashMap[String, Long]() - + + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + val currentJars: HashMap[String, Long] = new HashMap[String, Long]() val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) @@ -67,9 +68,9 @@ class Executor extends Logging { try { SparkEnv.set(env) Accumulators.clear() - val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader) - task.downloadDependencies(fileSet, jarSet) - updateClassLoader() + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + updateDependencies(taskFiles, taskJars) + val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) @@ -104,12 +105,11 @@ class Executor extends Logging { * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - - var loader = this.getClass().getClassLoader() + var loader = this.getClass.getClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. - val urls = jarSet.keySet.map { uri => + val urls = currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL }.toArray loader = new URLClassLoader(urls, loader) @@ -134,22 +134,28 @@ class Executor extends Logging { return new ExecutorURLClassLoader(Array(), loader) } - def updateClassLoader() { - val currentURLs = urlClassLoader.getURLs() - val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL } - urlSet.filterNot(currentURLs.contains(_)).foreach { url => - logInfo("Adding " + url + " to the class loader.") - urlClassLoader.addURL(url) + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. 
+   */
+  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+    // Fetch missing dependencies
+    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+      logInfo("Fetching " + name)
+      Utils.fetchFile(name, new File("."))
+      currentFiles(name) = timestamp
     }
-
-  }
-
-  // The addURL method in URLClassLoader is protected. We subclass it to make it accessible.
-  class ExecutorURLClassLoader(urls : Array[URL], parent : ClassLoader)
-    extends URLClassLoader(urls, parent) {
-    override def addURL(url: URL) {
-      super.addURL(url)
+    for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+      logInfo("Fetching " + name)
+      Utils.fetchFile(name, new File("."))
+      currentJars(name) = timestamp
+      // Add it to our class loader
+      val localName = name.split("/").last
+      val url = new File(".", localName).toURI.toURL
+      if (!urlClassLoader.getURLs.contains(url)) {
+        logInfo("Adding " + url + " to class loader")
+        urlClassLoader.addURL(url)
+      }
     }
   }
-
 }
diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
new file mode 100644
index 0000000000..f74f036c4c
--- /dev/null
+++ b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
@@ -0,0 +1,15 @@
+package spark.executor
+
+import java.net.{URLClassLoader, URL}
+
+/**
+ * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
+ */
+private[spark]
+class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
+  extends URLClassLoader(urls, parent) {
+
+  override def addURL(url: URL) {
+    super.addURL(url)
+  }
+}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 745aa0c939..d70a061366 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -21,8 +21,6 @@ object ShuffleMapTask {
   // Served as a cache for task serialization because serialization can be
   // expensive on the master node if it needs to launch thousands of tasks.
   val serializedInfoCache = new JHashMap[Int, Array[Byte]]
-  val fileSetCache = new JHashMap[Int, Array[Byte]]
-  val jarSetCache = new JHashMap[Int, Array[Byte]]

   def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
     synchronized {
@@ -43,23 +41,6 @@
     }
   }

-  // Since both the JarSet and FileSet have the same format this is used for both.
- def serializeFileSet( - set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = { - val old = cache.get(stageId) - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(set.toArray) - objOut.close() - val bytes = out.toByteArray - cache.put(stageId, bytes) - return bytes - } - } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { synchronized { val loader = Thread.currentThread.getContextClassLoader @@ -83,8 +64,6 @@ object ShuffleMapTask { def clearCache() { synchronized { serializedInfoCache.clear() - fileSetCache.clear() - jarSetCache.clear() } } } @@ -112,15 +91,6 @@ class ShuffleMapTask( val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) - - val fileSetBytes = ShuffleMapTask.serializeFileSet( - fileSet, stageId, ShuffleMapTask.fileSetCache) - out.writeInt(fileSetBytes.length) - out.write(fileSetBytes) - val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet, stageId, ShuffleMapTask.jarSetCache) - out.writeInt(jarSetBytes.length) - out.write(jarSetBytes) - out.writeInt(partition) out.writeLong(generation) out.writeObject(split) @@ -134,17 +104,6 @@ class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ - - val fileSetNumBytes = in.readInt() - val fileSetBytes = new Array[Byte](fileSetNumBytes) - in.readFully(fileSetBytes) - fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes) - - val jarSetNumBytes = in.readInt() - val jarSetBytes = new Array[Byte](jarSetNumBytes) - in.readFully(jarSetBytes) - jarSet = ShuffleMapTask.deserializeFileSet(jarSetBytes) - partition = in.readInt() generation = in.readLong() split = in.readObject().asInstanceOf[Split] diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala index 6128e0b273..d69c259362 100644 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ b/core/src/main/scala/spark/scheduler/Task.scala @@ -1,9 +1,12 @@ package spark.scheduler import scala.collection.mutable.{HashMap} -import spark.HttpFileServer -import spark.Utils -import java.io.File +import spark.{SerializerInstance, Serializer, Utils} +import java.io.{DataInputStream, DataOutputStream, File} +import java.nio.ByteBuffer +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import spark.util.ByteBufferInputStream +import scala.collection.mutable.HashMap /** * A task to execute on a worker node. @@ -13,30 +16,80 @@ abstract class Task[T](val stageId: Int) extends Serializable { def preferredLocations: Seq[String] = Nil var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. - - // Stores jar and file dependencies for this task. - var fileSet : HashMap[String, Long] = new HashMap[String, Long]() - var jarSet : HashMap[String, Long] = new HashMap[String, Long]() - - // Downloads all file dependencies from the Master file server - def downloadDependencies(currentFileSet : HashMap[String, Long], - currentJarSet : HashMap[String, Long]) { - - // Fetch missing file dependencies - fileSet.filter { case(k,v) => - !currentFileSet.contains(k) || currentFileSet(k) < v - }.foreach { case (k,v) => - Utils.fetchFile(k, new File(System.getProperty("user.dir"))) - currentFileSet(k) = v +} + +/** + * Handles transmission of tasks and their dependencies, because this can be slightly tricky. 
We + * need to send the list of JARs and files added to the SparkContext with each task to ensure that + * worker nodes find out about it, but we can't make it part of the Task because the user's code in + * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by + * first writing out its dependencies. + */ +object Task { + /** + * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) + */ + def serializeWithDependencies( + task: Task[_], + currentFiles: HashMap[String, Long], + currentJars: HashMap[String, Long], + serializer: SerializerInstance) + : ByteBuffer = { + + val out = new FastByteArrayOutputStream(4096) + val dataOut = new DataOutputStream(out) + + // Write currentFiles + dataOut.writeInt(currentFiles.size) + for ((name, timestamp) <- currentFiles) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) } - // Fetch missing jar dependencies - jarSet.filter { case(k,v) => - !currentJarSet.contains(k) || currentJarSet(k) < v - }.foreach { case (k,v) => - Utils.fetchFile(k, new File(System.getProperty("user.dir"))) - currentJarSet(k) = v + + // Write currentJars + dataOut.writeInt(currentJars.size) + for ((name, timestamp) <- currentJars) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) } - + + // Write the task itself and finish + dataOut.flush() + val taskBytes = serializer.serialize(task).array() + out.write(taskBytes) + out.trim() + ByteBuffer.wrap(out.array) } - -} + + /** + * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, + * and return the task itself as a serialized ByteBuffer. The caller can then update its + * ClassLoaders and deserialize the task. + * + * @return (taskFiles, taskJars, taskBytes) + */ + def deserializeWithDependencies(serializedTask: ByteBuffer) + : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + + val in = new ByteBufferInputStream(serializedTask) + val dataIn = new DataInputStream(in) + + // Read task's files + val taskFiles = new HashMap[String, Long]() + val numFiles = dataIn.readInt() + for (i <- 0 until numFiles) { + taskFiles(dataIn.readUTF()) = dataIn.readLong() + } + + // Read task's JARs + val taskJars = new HashMap[String, Long]() + val numJars = dataIn.readInt() + for (i <- 0 until numJars) { + taskJars(dataIn.readUTF()) = dataIn.readLong() + } + + // Create a sub-buffer for the rest of the data, which is the serialized Task object + val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task + (taskFiles, taskJars, subBuffer) + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 952c9766bf..16fe5761c8 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -16,7 +16,7 @@ import java.util.concurrent.atomic.AtomicLong * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call * start(), then submit task sets through the runTasks method. 
*/ -class ClusterScheduler(sc: SparkContext) +class ClusterScheduler(val sc: SparkContext) extends TaskScheduler with Logging { @@ -87,10 +87,6 @@ class ClusterScheduler(sc: SparkContext) def submitTasks(taskSet: TaskSet) { val tasks = taskSet.tasks - tasks.foreach { task => - task.fileSet ++= sc.addedFiles - task.jarSet ++= sc.addedJars - } logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { val manager = new TaskSetManager(this, taskSet) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index e25a11e7c5..aa37462fb0 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -214,7 +214,8 @@ class TaskSetManager( } // Serialize and return the task val startTime = System.currentTimeMillis - val serializedTask = ser.serialize(task) + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) val timeTaken = System.currentTimeMillis - startTime logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 65078b026e..53fc659345 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.HashMap import spark._ +import executor.ExecutorURLClassLoader import spark.scheduler._ /** @@ -14,13 +15,21 @@ import spark.scheduler._ * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ -class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging { +class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + extends TaskScheduler + with Logging { + var attemptId = new AtomicInteger(0) var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null - val fileSet: HashMap[String, Long] = new HashMap[String, Long]() - val jarSet: HashMap[String, Long] = new HashMap[String, Long]() + + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. 
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + + val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) // TODO: Need to take into account stage priority in scheduling @@ -35,8 +44,6 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T val failCount = new Array[Int](tasks.size) def submitTask(task: Task[_], idInJob: Int) { - task.fileSet ++= sc.addedFiles - task.jarSet ++= sc.addedJars val myAttemptId = attemptId.getAndIncrement() threadPool.submit(new Runnable { def run() { @@ -49,19 +56,23 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T logInfo("Running task " + idInJob) // Set the Spark execution environment for the worker thread SparkEnv.set(env) - task.downloadDependencies(fileSet, jarSet) - // Create a new classLaoder for the downloaded JARs - Thread.currentThread.setContextClassLoader(createClassLoader()) try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - Accumulators.clear val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = ser.serialize(task) + val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile val deserializedTask = ser.deserialize[Task[_]]( - bytes, Thread.currentThread.getContextClassLoader) + taskBytes, Thread.currentThread.getContextClassLoader) + + // Run it val result: Any = deserializedTask.run(attemptId) + // Serialize and deserialize the result to emulate what the Mesos // executor does. This is useful to catch serialization errors early // on in development (so when users move their local Spark programs @@ -90,20 +101,35 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T submitTask(task, i) } } - + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. 
+   */
+  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+    // Fetch missing dependencies
+    for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+      logInfo("Fetching " + name)
+      Utils.fetchFile(name, new File("."))
+      currentFiles(name) = timestamp
+    }
+    for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+      logInfo("Fetching " + name)
+      Utils.fetchFile(name, new File("."))
+      currentJars(name) = timestamp
+      // Add it to our class loader
+      val localName = name.split("/").last
+      val url = new File(".", localName).toURI.toURL
+      if (!classLoader.getURLs.contains(url)) {
+        logInfo("Adding " + url + " to class loader")
+        classLoader.addURL(url)
+      }
+    }
+  }

   override def stop() {
     threadPool.shutdownNow()
   }

-  private def createClassLoader() : ClassLoader = {
-    val currentLoader = Thread.currentThread.getContextClassLoader()
-    val urls = jarSet.keySet.map { uri =>
-      new File(uri.split("/").last).toURI.toURL
-    }.toArray
-    logInfo("Creating ClassLoader with jars: " + urls.mkString)
-    return new URLClassLoader(urls, currentLoader)
-  }
-
   override def defaultParallelism() = threads
 }
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 2f14db4e28..8e4f9f7c15 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -395,10 +395,12 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
   }

   def mustRegisterBlockManager(msg: RegisterBlockManager) {
+    logInfo("Trying to register BlockManager")
     while (! syncRegisterBlockManager(msg)) {
       logWarning("Failed to register " + msg)
       Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
     }
+    logInfo("Done registering BlockManager")
   }

   def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
index 0ce255105a..c92b60a40c 100644
--- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -31,4 +31,6 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
     buffer.position(buffer.position + amountToSkip)
     return amountToSkip
   }
+
+  def position: Int = buffer.position
 }
-- 
cgit v1.2.3
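
For reference, the wire format that Task.serializeWithDependencies introduces above is simply: the file map (name and timestamp pairs), then the JAR map, then the serializer-specific bytes of the task itself. Writing the dependency lists first is what lets an executor fix up its class loader before it ever tries to deserialize the task. The standalone sketch below (JDK-only; the object name, sample URLs and payload are illustrative and not taken from this commit) mirrors that layout:

// Standalone sketch of the layout used by Task.serializeWithDependencies /
// deserializeWithDependencies. Everything here (object name, URLs, payload)
// is illustrative; only the byte layout mirrors the patch.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import scala.collection.mutable.HashMap

object TaskWireFormatSketch {
  // [numFiles][(name, timestamp)...][numJars][(name, timestamp)...][task bytes]
  def serialize(files: HashMap[String, Long], jars: HashMap[String, Long], taskBytes: Array[Byte]): Array[Byte] = {
    val out = new ByteArrayOutputStream(4096)
    val dataOut = new DataOutputStream(out)
    def writeMap(m: HashMap[String, Long]) {
      dataOut.writeInt(m.size)
      for ((name, timestamp) <- m) {
        dataOut.writeUTF(name)
        dataOut.writeLong(timestamp)
      }
    }
    writeMap(files)
    writeMap(jars)
    dataOut.flush()
    out.write(taskBytes)   // the serializer-specific task comes last, still opaque
    out.toByteArray
  }

  def deserialize(bytes: Array[Byte]): (HashMap[String, Long], HashMap[String, Long], Array[Byte]) = {
    val in = new ByteArrayInputStream(bytes)
    val dataIn = new DataInputStream(in)
    def readMap(): HashMap[String, Long] = {
      val m = new HashMap[String, Long]()
      val n = dataIn.readInt()
      for (_ <- 0 until n) {
        m(dataIn.readUTF()) = dataIn.readLong()
      }
      m
    }
    val files = readMap()
    val jars = readMap()
    // Whatever is left is the serialized Task; the caller can update its class
    // loaders from `files` and `jars` before deserializing it.
    val rest = new Array[Byte](in.available())
    dataIn.readFully(rest)
    (files, jars, rest)
  }

  def main(args: Array[String]) {
    val files = HashMap("http://master:12345/files/lookup.txt" -> 100L)
    val jars = HashMap("http://master:12345/jars/app.jar" -> 100L)
    val payload = "pretend this is a serialized Task".getBytes("UTF-8")
    val (f, j, task) = deserialize(serialize(files, jars, payload))
    assert(f == files && j == jars && new String(task, "UTF-8") == "pretend this is a serialized Task")
  }
}

Because the task bytes stay opaque until after the two maps are read, a custom serializer or Kryo registrator fetched from one of those JARs can be in place before deserialization begins.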
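
The case that motivates all of this is a Kryo registrator that lives in the application's own JAR: KryoSerializer.kryo is now lazy so that, by the time it is first used on an executor, the JAR has been fetched and added to the ExecutorURLClassLoader and the registrator class can be loaded by name. A hypothetical user-side registrator is sketched below; the package, class name and registered types are invented, the KryoRegistrator trait is the one from the KryoSerializer file touched above, and spark.kryo.registrator is assumed to be the system property KryoSerializer consults (spark.serializer is the one SparkEnv reads in this patch).

// Hypothetical application code, not part of this patch.
package com.example.myapp

import com.esotericsoftware.kryo.Kryo
import spark.KryoRegistrator

// Shipped to executors inside the application JAR; with this change the executor
// adds that JAR to its URL class loader before deserializing the first task, so
// the lazy KryoSerializer.kryo can instantiate this class by name.
class MyRegistrator extends KryoRegistrator {
  override def registerClasses(kryo: Kryo) {
    kryo.register(classOf[Array[String]])
    kryo.register(classOf[Tuple2[String, Double]])
  }
}

On the driver side the application would set spark.serializer to spark.KryoSerializer and spark.kryo.registrator to com.example.myapp.MyRegistrator, and add the JAR through the SparkContext so that it appears in sc.addedJars and travels with each serialized task.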