author | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-10-16 16:14:13 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2010-10-16 16:14:13 -0700 |
commit | 1c082ad5fbfbcb72044a96b7c0b71329ae8e682a | |
tree | f88c22f559c5adb763a0d1ebda954e7f50b0532c | |
parent | c0b856a056f5e101bc00b5cd4a8d5b5d91e488f1 | |
Added the ability to specify a list of JAR files when creating a
SparkContext and have the master node serve those to workers.
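For orientation, here is a minimal usage sketch of the new constructor parameter introduced by this commit. The master string, framework name, and JAR paths below are placeholders, not values from the commit; on a Mesos master, the scheduler serves the listed JARs over HTTP and each executor adds them to the class loader used for tasks.

```scala
// Hypothetical example: pass application JARs when creating the context.
val sc = new SparkContext(
  "localhost:5050",                                 // master (placeholder; "local[2]" runs locally)
  "MyJob",                                          // framework name
  Seq("/path/to/myJob.jar", "/path/to/deps.jar"))   // JARs served to workers
```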
-rw-r--r-- | Makefile | 4
-rw-r--r-- | src/scala/spark/Executor.scala | 154
-rw-r--r-- | src/scala/spark/HttpServer.scala | 5
-rw-r--r-- | src/scala/spark/MesosScheduler.scala | 91
-rw-r--r-- | src/scala/spark/SparkContext.scala | 37
-rw-r--r-- | src/scala/spark/Utils.scala | 45
-rw-r--r-- | src/scala/spark/repl/SparkInterpreter.scala | 28
7 files changed, 247 insertions, 117 deletions
diff --git a/Makefile b/Makefile
@@ -50,6 +50,8 @@ native: java
 
 jar: build/spark.jar build/spark-dep.jar
 
+depjar: build/spark-dep.jar
+
 build/spark.jar: scala java
 	jar cf build/spark.jar -C build/classes spark
 
@@ -67,4 +69,4 @@ clean:
 	$(MAKE) -C src/native clean
 	rm -rf build
 
-.phony: default all clean scala java native jar
+.phony: default all clean scala java native jar depjar
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
index be73aae541..e47d3757b6 100644
--- a/src/scala/spark/Executor.scala
+++ b/src/scala/spark/Executor.scala
@@ -1,75 +1,115 @@
 package spark
 
+import java.io.{File, FileOutputStream}
+import java.net.{URI, URL, URLClassLoader}
 import java.util.concurrent.{Executors, ExecutorService}
 
+import scala.collection.mutable.ArrayBuffer
+
 import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
 import mesos.{TaskDescription, TaskState, TaskStatus}
 
 /**
  * The Mesos executor for Spark.
  */
-object Executor extends Logging {
-  def main(args: Array[String]) {
-    System.loadLibrary("mesos")
+class Executor extends mesos.Executor with Logging {
+  var classLoader: ClassLoader = null
+  var threadPool: ExecutorService = null
 
-    // Create a new Executor implementation that will run our tasks
-    val exec = new mesos.Executor() {
-      var classLoader: ClassLoader = null
-      var threadPool: ExecutorService = null
-
-      override def init(d: ExecutorDriver, args: ExecutorArgs) {
-        // Read spark.* system properties
-        val props = Utils.deserialize[Array[(String, String)]](args.getData)
-        for ((key, value) <- props)
-          System.setProperty(key, value)
-
-        // Initialize broadcast system (uses some properties read above)
-        Broadcast.initialize(false)
-
-        // If the REPL is in use, create a ClassLoader that will be able to
-        // read new classes defined by the REPL as the user types code
-        classLoader = this.getClass.getClassLoader
-        val classUri = System.getProperty("spark.repl.class.uri")
-        if (classUri != null) {
-          logInfo("Using REPL class URI: " + classUri)
-          classLoader = new repl.ExecutorClassLoader(classUri, classLoader)
-        }
-        Thread.currentThread.setContextClassLoader(classLoader)
-
-        // Start worker thread pool (they will inherit our context ClassLoader)
-        threadPool = Executors.newCachedThreadPool()
-      }
-
-      override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
-        // Pull taskId and arg out of TaskDescription because it won't be a
-        // valid pointer after this method call (TODO: fix this in C++/SWIG)
-        val taskId = desc.getTaskId
-        val arg = desc.getArg
-        threadPool.execute(new Runnable() {
-          def run() = {
-            logInfo("Running task ID " + taskId)
-            try {
-              Accumulators.clear
-              val task = Utils.deserialize[Task[Any]](arg, classLoader)
-              val value = task.run
-              val accumUpdates = Accumulators.values
-              val result = new TaskResult(value, accumUpdates)
-              d.sendStatusUpdate(new TaskStatus(
-                taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
-              logInfo("Finished task ID " + taskId)
-            } catch {
-              case e: Exception => {
-                // TODO: Handle errors in tasks less dramatically
-                logError("Exception in task ID " + taskId, e)
-                System.exit(1)
-              }
-            }
+  override def init(d: ExecutorDriver, args: ExecutorArgs) {
+    // Read spark.* system properties from executor arg
+    val props = Utils.deserialize[Array[(String, String)]](args.getData)
+    for ((key, value) <- props)
+      System.setProperty(key, value)
+
+    // Initialize broadcast system (uses some properties read above)
+    Broadcast.initialize(false)
+
+    // Create our ClassLoader (using spark properties) and set it on this thread
+    classLoader = createClassLoader()
+    Thread.currentThread.setContextClassLoader(classLoader)
+
+    // Start worker thread pool (they will inherit our context ClassLoader)
+    threadPool = Executors.newCachedThreadPool()
+  }
+
+  override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
+    // Pull taskId and arg out of TaskDescription because it won't be a
+    // valid pointer after this method call (TODO: fix this in C++/SWIG)
+    val taskId = desc.getTaskId
+    val arg = desc.getArg
+    threadPool.execute(new Runnable() {
+      def run() = {
+        logInfo("Running task ID " + taskId)
+        try {
+          Accumulators.clear
+          val task = Utils.deserialize[Task[Any]](arg, classLoader)
+          val value = task.run
+          val accumUpdates = Accumulators.values
+          val result = new TaskResult(value, accumUpdates)
+          d.sendStatusUpdate(new TaskStatus(
+            taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
+          logInfo("Finished task ID " + taskId)
+        } catch {
+          case e: Exception => {
+            // TODO: Handle errors in tasks less dramatically
+            logError("Exception in task ID " + taskId, e)
+            System.exit(1)
           }
-        })
+        }
       }
+    })
+  }
+
+  // Create a ClassLoader for use in tasks, adding any JARs specified by the
+  // user or any classes created by the interpreter to the search path
+  private def createClassLoader(): ClassLoader = {
+    var loader = this.getClass.getClassLoader
+
+    // If any JAR URIs are given through spark.jar.uris, fetch them to the
+    // current directory and put them all on the classpath. We assume that
+    // each URL has a unique file name so that no local filenames will clash
+    // in this process. This is guaranteed by MesosScheduler.
+    val uris = System.getProperty("spark.jar.uris", "")
+    val localFiles = ArrayBuffer[String]()
+    for (uri <- uris.split(",").filter(_.size > 0)) {
+      val url = new URL(uri)
+      val filename = url.getPath.split("/").last
+      downloadFile(url, filename)
+      localFiles += filename
+    }
+    if (localFiles.size > 0) {
+      val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
+      loader = new URLClassLoader(urls, loader)
     }
 
-    // Start it running and connect it to the slave
+    // If the REPL is in use, add another ClassLoader that will read
+    // new classes defined by the REPL as the user types code
+    val classUri = System.getProperty("spark.repl.class.uri")
+    if (classUri != null) {
+      logInfo("Using REPL class URI: " + classUri)
+      loader = new repl.ExecutorClassLoader(classUri, loader)
+    }
+
+    return loader
+  }
+
+  // Download a file from a given URL to the local filesystem
+  private def downloadFile(url: URL, localPath: String) {
+    val in = url.openStream()
+    val out = new FileOutputStream(localPath)
+    Utils.copyStream(in, out, true)
+  }
+}
+
+/**
+ * Executor entry point.
+ */
+object Executor extends Logging {
+  def main(args: Array[String]) {
+    System.loadLibrary("mesos")
+    // Create a new Executor and start it running
+    val exec = new Executor
     new MesosExecutorDriver(exec).run()
   }
 }
diff --git a/src/scala/spark/HttpServer.scala b/src/scala/spark/HttpServer.scala
index 55fb0a2218..08eb08bbe3 100644
--- a/src/scala/spark/HttpServer.scala
+++ b/src/scala/spark/HttpServer.scala
@@ -7,6 +7,7 @@ import org.eclipse.jetty.server.Server
 import org.eclipse.jetty.server.handler.DefaultHandler
 import org.eclipse.jetty.server.handler.HandlerList
 import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.util.thread.QueuedThreadPool
 
 /**
@@ -30,6 +31,9 @@ class HttpServer(resourceBase: File) extends Logging {
       throw new ServerStateException("Server is already started")
     } else {
       server = new Server(0)
+      val threadPool = new QueuedThreadPool
+      threadPool.setDaemon(true)
+      server.setThreadPool(threadPool)
       val resHandler = new ResourceHandler
       resHandler.setResourceBase(resourceBase.getAbsolutePath)
       val handlerList = new HandlerList
@@ -37,7 +41,6 @@ class HttpServer(resourceBase: File) extends Logging {
       server.setHandler(handlerList)
       server.start()
       port = server.getConnectors()(0).getLocalPort()
-      logDebug("HttpServer started at " + uri)
     }
   }
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index 470be69e50..0f7adb4826 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -1,10 +1,11 @@
 package spark
 
-import java.io.File
+import java.io.{File, FileInputStream, FileOutputStream}
 import java.util.{ArrayList => JArrayList}
 import java.util.{List => JList}
 import java.util.{HashMap => JHashMap}
 
+import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.collection.mutable.Map
@@ -19,7 +20,7 @@ import mesos._
  * first call start(), then submit tasks through the runTasks method.
  */
 private class MesosScheduler(
-  sc: SparkContext, master: String, frameworkName: String, execArg: Array[Byte])
+  sc: SparkContext, master: String, frameworkName: String)
 extends MScheduler with spark.Scheduler with Logging
 {
   // Environment variables to pass to our executors
@@ -39,27 +40,37 @@ extends MScheduler with spark.Scheduler with Logging
   private var taskIdToJobId = new HashMap[Int, Int]
   private var jobTasks = new HashMap[Int, HashSet[Int]]
 
+  // Incrementing job and task IDs
   private var nextJobId = 0
-
+  private var nextTaskId = 0
+
+  // Driver for talking to Mesos
+  var driver: SchedulerDriver = null
+
+  // JAR server, if any JARs were added by the user to the SparkContext
+  var jarServer: HttpServer = null
+
+  // URIs of JARs to pass to executor
+  var jarUris: String = ""
+
   def newJobId(): Int = this.synchronized {
     val id = nextJobId
     nextJobId += 1
     return id
   }
 
-  // Incrementing task ID
-  private var nextTaskId = 0
-
   def newTaskId(): Int = {
     val id = nextTaskId;
     nextTaskId += 1;
     return id
   }
-
-  // Driver for talking to Mesos
-  var driver: SchedulerDriver = null
 
   override def start() {
+    if (sc.jars.size > 0) {
+      // If the user added any JARS to the SparkContext, create an HTTP server
+      // to serve them to our executors
+      createJarServer()
+    }
     new Thread("Spark scheduler") {
       setDaemon(true)
       override def run {
@@ -83,10 +94,11 @@ extends MScheduler with spark.Scheduler with Logging
     val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
     val params = new JHashMap[String, String]
     for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
-      if (System.getenv(key) != null)
+      if (System.getenv(key) != null) {
         params("env." + key) = System.getenv(key)
+      }
     }
-    new ExecutorInfo(execScript, execArg)
+    new ExecutorInfo(execScript, createExecArg())
   }
 
   /**
@@ -220,10 +232,63 @@ extends MScheduler with spark.Scheduler with Logging
   }
 
   override def stop() {
-    if (driver != null)
+    if (driver != null) {
       driver.stop()
+    }
+    if (jarServer != null) {
+      jarServer.stop()
+    }
   }
 
   // TODO: query Mesos for number of cores
-  override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt
+  override def numCores() =
+    System.getProperty("spark.default.parallelism", "2").toInt
+
+  // Create a server for all the JARs added by the user to SparkContext.
+  // We first copy the JARs to a temp directory for easier server setup.
+  private def createJarServer() {
+    val jarDir = Utils.createTempDir()
+    logInfo("Temp directory for JARs: " + jarDir)
+    val filenames = ArrayBuffer[String]()
+    // Copy each JAR to a unique filename in the jarDir
+    for ((path, index) <- sc.jars.zipWithIndex) {
+      val file = new File(path)
+      val filename = index + "_" + file.getName
+      copyFile(file, new File(jarDir, filename))
+      filenames += filename
+    }
+    // Create the server
+    jarServer = new HttpServer(jarDir)
+    jarServer.start()
+    // Build up the jar URI list
+    val serverUri = jarServer.uri
+    jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+    logInfo("JAR server started at " + serverUri)
+  }
+
+  // Copy a file on the local file system
+  private def copyFile(source: File, dest: File) {
+    val in = new FileInputStream(source)
+    val out = new FileOutputStream(dest)
+    Utils.copyStream(in, out, true)
+  }
+
+  // Create and serialize the executor argument to pass to Mesos.
+  // Our executor arg is an array containing all the spark.* system properties
+  // in the form of (String, String) pairs.
+  private def createExecArg(): Array[Byte] = {
+    val props = new HashMap[String, String]
+    val iter = System.getProperties.entrySet.iterator
+    while (iter.hasNext) {
+      val entry = iter.next
+      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+      if (key.startsWith("spark.")) {
+        props(key) = value
+      }
+    }
+    // Set spark.jar.uris to our JAR URIs, regardless of system property
+    props("spark.jar.uris") = jarUris
+    // Serialize the map as an array of (String, String) pairs
+    return Utils.serialize(props.toArray)
+  }
 }
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index 3abce13a86..b9870cc3b9 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -1,12 +1,18 @@
 package spark
 
 import java.io._
-import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
 
-class SparkContext(master: String, frameworkName: String) extends Logging {
+class SparkContext(
+    master: String,
+    frameworkName: String,
+    val jars: Seq[String] = Nil)
+extends Logging {
+  // Spark home directory, used to resolve executor when running on Mesos
+  private var sparkHome: Option[String] = None
+
   private[spark] var scheduler: Scheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
 
@@ -17,18 +23,16 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
         new LocalScheduler(threads.toInt)
       case _ =>
         System.loadLibrary("mesos")
-        new MesosScheduler(this, master, frameworkName, createExecArg())
+        new MesosScheduler(this, master, frameworkName)
     }
   }
 
-  private val local = scheduler.isInstanceOf[LocalScheduler]
+  private val isLocal = scheduler.isInstanceOf[LocalScheduler]
 
+  // Start the scheduler and the broadcast system
   scheduler.start()
-
   Broadcast.initialize(true)
 
-  private var sparkHome: Option[String] = None
-
   // Methods for creating RDDs
 
   def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
@@ -45,22 +49,8 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
     new Accumulator(initialValue, param)
 
   // TODO: Keep around a weak hash map of values to Cached versions?
-  def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
-  //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)
-
-  // Create and serialize an executor argument to use when running on Mesos
-  private def createExecArg(): Array[Byte] = {
-    // Our executor arg is an array containing all the spark.* system properties
-    val props = new ArrayBuffer[(String, String)]
-    val iter = System.getProperties.entrySet.iterator
-    while (iter.hasNext) {
-      val entry = iter.next
-      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
-      if (key.startsWith("spark."))
-        props += key -> value
-    }
-    return Utils.serialize(props.toArray)
-  }
+  def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, isLocal)
+  //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, isLocal)
 
   // Stop the SparkContext
   def stop() {
@@ -94,7 +84,6 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
     None
   }
 
-
   // Submit an array of tasks (passed as functions) to the scheduler
   def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = {
     runTaskObjects(tasks.map(f => new FunctionTask(f)))
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
index 27d73aefbd..9d300d229a 100644
--- a/src/scala/spark/Utils.scala
+++ b/src/scala/spark/Utils.scala
@@ -1,9 +1,13 @@
 package spark
 
 import java.io._
+import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
 
+/**
+ * Various utility methods used by Spark.
+ */
 object Utils {
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream
@@ -50,4 +54,45 @@ object Utils {
     }
     return buf
   }
+
+  // Create a temporary directory inside the given parent directory
+  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File =
+  {
+    var attempts = 0
+    val maxAttempts = 10
+    var dir: File = null
+    while (dir == null) {
+      attempts += 1
+      if (attempts > maxAttempts) {
+        throw new IOException("Failed to create a temp directory " +
+                              "after " + maxAttempts + " attempts!")
+      }
+      try {
+        dir = new File(root, "spark-" + UUID.randomUUID.toString)
+        if (dir.exists() || !dir.mkdirs()) {
+          dir = null
+        }
+      } catch { case e: IOException => ; }
+    }
+    return dir
+  }
+
+  // Copy all data from an InputStream to an OutputStream
+  def copyStream(in: InputStream,
+                 out: OutputStream,
+                 closeStreams: Boolean = false)
+  {
+    val buf = new Array[Byte](8192)
+    var n = 0
+    while (n != -1) {
+      n = in.read(buf)
+      if (n != -1) {
+        out.write(buf, 0, n)
+      }
+    }
+    if (closeStreams) {
+      in.close()
+      out.close()
+    }
+  }
 }
diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala
index 41324333a3..10ea346658 100644
--- a/src/scala/spark/repl/SparkInterpreter.scala
+++ b/src/scala/spark/repl/SparkInterpreter.scala
@@ -37,6 +37,7 @@ import interpreter._
 import SparkInterpreter._
 
 import spark.HttpServer
+import spark.Utils
 
 /** <p>
  *  An interpreter for Scala code.
@@ -94,27 +95,12 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) {
 
   /** Local directory to save .class files too */
   val outputDir = {
-    val rootDir = new File(System.getProperty("spark.repl.classdir",
-                           System.getProperty("java.io.tmpdir")))
-    var attempts = 0
-    val maxAttempts = 10
-    var dir: File = null
-    while (dir == null) {
-      attempts += 1
-      if (attempts > maxAttempts) {
-        throw new IOException("Failed to create a temp directory " +
-                              "after " + maxAttempts + " attempts!")
-      }
-      try {
-        dir = new File(rootDir, "spark-" + UUID.randomUUID.toString)
-        if (dir.exists() || !dir.mkdirs())
-          dir = null
-      } catch { case e: IOException => ; }
-    }
-    if (SPARK_DEBUG_REPL) {
-      println("Output directory: " + dir)
-    }
-    dir
+    val tmp = System.getProperty("java.io.tmpdir")
+    val rootDir = System.getProperty("spark.repl.classdir", tmp)
+    Utils.createTempDir(rootDir)
+  }
+  if (SPARK_DEBUG_REPL) {
+    println("Output directory: " + outputDir)
   }
 
   /** Scala compiler virtual directory for outputDir */