diff options
23 files changed, 922 insertions, 508 deletions
@@ -5,7 +5,7 @@ SPACE = $(EMPTY) $(EMPTY) JARS = third_party/mesos.jar JARS += third_party/asm-3.2/lib/all/asm-all-3.2.jar JARS += third_party/colt.jar -JARS += third_party/guava-r06/guava-r06.jar +JARS += third_party/guava-r07/guava-r07.jar JARS += third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar JARS += third_party/hadoop-0.20.0/lib/commons-logging-1.0.4.jar JARS += third_party/scalatest-1.2/scalatest-1.2.jar @@ -34,7 +34,9 @@ else COMPILER = $(SCALA_HOME)/bin/$(COMPILER_NAME) endif -all: scala java +CONF_FILES = conf/spark-env.sh conf/log4j.properties conf/java-opts + +all: scala java conf-files build/classes: mkdir -p build/classes @@ -50,6 +52,8 @@ native: java jar: build/spark.jar build/spark-dep.jar +dep-jar: build/spark-dep.jar + build/spark.jar: scala java jar cf build/spark.jar -C build/classes spark @@ -58,6 +62,11 @@ build/spark-dep.jar: cd build/dep && for i in $(JARS); do jar xf ../../$$i; done jar cf build/spark-dep.jar -C build/dep . +conf-files: $(CONF_FILES) + +$(CONF_FILES): %: | %.template + cp $@.template $@ + test: all ./alltests @@ -67,4 +76,4 @@ clean: $(MAKE) -C src/native clean rm -rf build -.phony: default all clean scala java native jar +.phony: default all clean scala java native jar dep-jar conf-files @@ -1,3 +1,6 @@ #!/bin/bash -FWDIR=`dirname $0` +FWDIR="`dirname $0`" +if [ "x$SPARK_MEM" == "x" ]; then + export SPARK_MEM=500m +fi $FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -o $@ diff --git a/conf/java-opts b/conf/java-opts.template index b61e8163b5..b61e8163b5 100644 --- a/conf/java-opts +++ b/conf/java-opts.template diff --git a/conf/log4j.properties b/conf/log4j.properties.template index d72dbadc39..d72dbadc39 100644 --- a/conf/log4j.properties +++ b/conf/log4j.properties.template diff --git a/conf/spark-env.sh b/conf/spark-env.sh.template index 77f9cb69b9..77f9cb69b9 100755 --- a/conf/spark-env.sh +++ b/conf/spark-env.sh.template @@ -1,16 +1,22 @@ #!/bin/bash # Figure out where the Scala framework is installed -FWDIR=`dirname $0` +FWDIR="$(cd `dirname $0`; pwd)" + +# Export this as SPARK_HOME +export SPARK_HOME="$FWDIR" # Load environment variables from conf/spark-env.sh, if it exists if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh fi +MESOS_CLASSPATH="" +MESOS_LIBRARY_PATH="" + if [ "x$MESOS_HOME" != "x" ] ; then - SPARK_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar:$SPARK_CLASSPATH" - SPARK_LIBRARY_PATH="$MESOS_HOME/lib/java:$SPARK_LIBARY_PATH" + MESOS_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar" + MESOS_LIBRARY_PATH="$MESOS_HOME/lib/java" fi if [ "x$SPARK_MEM" == "x" ] ; then @@ -19,7 +25,7 @@ fi # Set JAVA_OPTS to be able to load native libraries and to set heap size JAVA_OPTS="$SPARK_JAVA_OPTS" -JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native" +JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native:$MESOS_LIBRARY_PATH" JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e $FWDIR/conf/java-opts ] ; then @@ -28,12 +34,12 @@ fi export JAVA_OPTS # Build up classpath -CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes" +CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes:$MESOS_CLASSPATH" CLASSPATH+=:$FWDIR/conf CLASSPATH+=:$FWDIR/third_party/mesos.jar CLASSPATH+=:$FWDIR/third_party/asm-3.2/lib/all/asm-all-3.2.jar CLASSPATH+=:$FWDIR/third_party/colt.jar -CLASSPATH+=:$FWDIR/third_party/guava-r06/guava-r06.jar +CLASSPATH+=:$FWDIR/third_party/guava-r07/guava-r07.jar CLASSPATH+=:$FWDIR/third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar CLASSPATH+=:$FWDIR/third_party/scalatest-1.2/scalatest-1.2.jar CLASSPATH+=:$FWDIR/third_party/scalacheck_2.8.0-1.7.jar diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala index be73aae541..e47d3757b6 100644 --- a/src/scala/spark/Executor.scala +++ b/src/scala/spark/Executor.scala @@ -1,75 +1,115 @@ package spark +import java.io.{File, FileOutputStream} +import java.net.{URI, URL, URLClassLoader} import java.util.concurrent.{Executors, ExecutorService} +import scala.collection.mutable.ArrayBuffer + import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver} import mesos.{TaskDescription, TaskState, TaskStatus} /** * The Mesos executor for Spark. */ -object Executor extends Logging { - def main(args: Array[String]) { - System.loadLibrary("mesos") +class Executor extends mesos.Executor with Logging { + var classLoader: ClassLoader = null + var threadPool: ExecutorService = null - // Create a new Executor implementation that will run our tasks - val exec = new mesos.Executor() { - var classLoader: ClassLoader = null - var threadPool: ExecutorService = null - - override def init(d: ExecutorDriver, args: ExecutorArgs) { - // Read spark.* system properties - val props = Utils.deserialize[Array[(String, String)]](args.getData) - for ((key, value) <- props) - System.setProperty(key, value) - - // Initialize broadcast system (uses some properties read above) - Broadcast.initialize(false) - - // If the REPL is in use, create a ClassLoader that will be able to - // read new classes defined by the REPL as the user types code - classLoader = this.getClass.getClassLoader - val classUri = System.getProperty("spark.repl.class.uri") - if (classUri != null) { - logInfo("Using REPL class URI: " + classUri) - classLoader = new repl.ExecutorClassLoader(classUri, classLoader) - } - Thread.currentThread.setContextClassLoader(classLoader) - - // Start worker thread pool (they will inherit our context ClassLoader) - threadPool = Executors.newCachedThreadPool() - } - - override def launchTask(d: ExecutorDriver, desc: TaskDescription) { - // Pull taskId and arg out of TaskDescription because it won't be a - // valid pointer after this method call (TODO: fix this in C++/SWIG) - val taskId = desc.getTaskId - val arg = desc.getArg - threadPool.execute(new Runnable() { - def run() = { - logInfo("Running task ID " + taskId) - try { - Accumulators.clear - val task = Utils.deserialize[Task[Any]](arg, classLoader) - val value = task.run - val accumUpdates = Accumulators.values - val result = new TaskResult(value, accumUpdates) - d.sendStatusUpdate(new TaskStatus( - taskId, TaskState.TASK_FINISHED, Utils.serialize(result))) - logInfo("Finished task ID " + taskId) - } catch { - case e: Exception => { - // TODO: Handle errors in tasks less dramatically - logError("Exception in task ID " + taskId, e) - System.exit(1) - } - } + override def init(d: ExecutorDriver, args: ExecutorArgs) { + // Read spark.* system properties from executor arg + val props = Utils.deserialize[Array[(String, String)]](args.getData) + for ((key, value) <- props) + System.setProperty(key, value) + + // Initialize broadcast system (uses some properties read above) + Broadcast.initialize(false) + + // Create our ClassLoader (using spark properties) and set it on this thread + classLoader = createClassLoader() + Thread.currentThread.setContextClassLoader(classLoader) + + // Start worker thread pool (they will inherit our context ClassLoader) + threadPool = Executors.newCachedThreadPool() + } + + override def launchTask(d: ExecutorDriver, desc: TaskDescription) { + // Pull taskId and arg out of TaskDescription because it won't be a + // valid pointer after this method call (TODO: fix this in C++/SWIG) + val taskId = desc.getTaskId + val arg = desc.getArg + threadPool.execute(new Runnable() { + def run() = { + logInfo("Running task ID " + taskId) + try { + Accumulators.clear + val task = Utils.deserialize[Task[Any]](arg, classLoader) + val value = task.run + val accumUpdates = Accumulators.values + val result = new TaskResult(value, accumUpdates) + d.sendStatusUpdate(new TaskStatus( + taskId, TaskState.TASK_FINISHED, Utils.serialize(result))) + logInfo("Finished task ID " + taskId) + } catch { + case e: Exception => { + // TODO: Handle errors in tasks less dramatically + logError("Exception in task ID " + taskId, e) + System.exit(1) } - }) + } } + }) + } + + // Create a ClassLoader for use in tasks, adding any JARs specified by the + // user or any classes created by the interpreter to the search path + private def createClassLoader(): ClassLoader = { + var loader = this.getClass.getClassLoader + + // If any JAR URIs are given through spark.jar.uris, fetch them to the + // current directory and put them all on the classpath. We assume that + // each URL has a unique file name so that no local filenames will clash + // in this process. This is guaranteed by MesosScheduler. + val uris = System.getProperty("spark.jar.uris", "") + val localFiles = ArrayBuffer[String]() + for (uri <- uris.split(",").filter(_.size > 0)) { + val url = new URL(uri) + val filename = url.getPath.split("/").last + downloadFile(url, filename) + localFiles += filename + } + if (localFiles.size > 0) { + val urls = localFiles.map(f => new File(f).toURI.toURL).toArray + loader = new URLClassLoader(urls, loader) } - // Start it running and connect it to the slave + // If the REPL is in use, add another ClassLoader that will read + // new classes defined by the REPL as the user types code + val classUri = System.getProperty("spark.repl.class.uri") + if (classUri != null) { + logInfo("Using REPL class URI: " + classUri) + loader = new repl.ExecutorClassLoader(classUri, loader) + } + + return loader + } + + // Download a file from a given URL to the local filesystem + private def downloadFile(url: URL, localPath: String) { + val in = url.openStream() + val out = new FileOutputStream(localPath) + Utils.copyStream(in, out, true) + } +} + +/** + * Executor entry point. + */ +object Executor extends Logging { + def main(args: Array[String]) { + System.loadLibrary("mesos") + // Create a new Executor and start it running + val exec = new Executor new MesosExecutorDriver(exec).run() } } diff --git a/src/scala/spark/HadoopFile.scala b/src/scala/spark/HadoopFile.scala new file mode 100644 index 0000000000..5746c433ee --- /dev/null +++ b/src/scala/spark/HadoopFile.scala @@ -0,0 +1,118 @@ +package spark + +import mesos.SlaveOffer + +import org.apache.hadoop.io.LongWritable +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapred.FileInputFormat +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.InputSplit +import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.TextInputFormat +import org.apache.hadoop.mapred.RecordReader +import org.apache.hadoop.mapred.Reporter +import org.apache.hadoop.util.ReflectionUtils + +/** A Spark split class that wraps around a Hadoop InputSplit */ +@serializable class HadoopSplit(@transient s: InputSplit) +extends Split { + val inputSplit = new SerializableWritable[InputSplit](s) + + // Hadoop gives each split a unique toString value, so use this as our ID + override def getId() = "HadoopSplit(" + inputSplit.toString + ")" +} + + +/** + * An RDD that reads a Hadoop file (from HDFS, S3, the local filesystem, etc) + * and represents it as a set of key-value pairs using a given InputFormat. + */ +class HadoopFile[K, V]( + sc: SparkContext, + path: String, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V]) +extends RDD[(K, V)](sc) { + @transient val splits_ : Array[Split] = ConfigureLock.synchronized { + val conf = new JobConf() + FileInputFormat.setInputPaths(conf, path) + val inputFormat = createInputFormat(conf) + val inputSplits = inputFormat.getSplits(conf, sc.scheduler.numCores) + inputSplits.map(x => new HadoopSplit(x): Split).toArray + } + + def createInputFormat(conf: JobConf): InputFormat[K, V] = { + ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) + .asInstanceOf[InputFormat[K, V]] + } + + override def splits = splits_ + + override def iterator(theSplit: Split) = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopSplit] + var reader: RecordReader[K, V] = null + + ConfigureLock.synchronized { + val conf = new JobConf() + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + val fmt = createInputFormat(conf) + reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) + } + + val key: K = keyClass.newInstance() + val value: V = valueClass.newInstance() + var gotNext = false + var finished = false + + override def hasNext: Boolean = { + if (!gotNext) { + try { + finished = !reader.next(key, value) + } catch { + case eofe: java.io.EOFException => + finished = true + } + gotNext = true + } + !finished + } + + override def next: (K, V) = { + if (!gotNext) { + finished = !reader.next(key, value) + } + if (finished) { + throw new java.util.NoSuchElementException("End of stream") + } + gotNext = false + (key, value) + } + } + + override def preferredLocations(split: Split) = { + // TODO: Filtering out "localhost" in case of file:// URLs + val hadoopSplit = split.asInstanceOf[HadoopSplit] + hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") + } +} + + +/** + * Convenience class for Hadoop files read using TextInputFormat that + * represents the file as an RDD of Strings. + */ +class HadoopTextFile(sc: SparkContext, path: String) +extends MappedRDD[String, (LongWritable, Text)]( + new HadoopFile(sc, path, classOf[TextInputFormat], + classOf[LongWritable], classOf[Text]), + { pair: (LongWritable, Text) => pair._2.toString } +) + + +/** + * Object used to ensure that only one thread at a time is configuring Hadoop + * InputFormat classes. Apparently configuring them is not thread safe! + */ +object ConfigureLock {} diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala deleted file mode 100644 index 8637c6e30a..0000000000 --- a/src/scala/spark/HdfsFile.scala +++ /dev/null @@ -1,80 +0,0 @@ -package spark - -import mesos.SlaveOffer - -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.Text -import org.apache.hadoop.mapred.FileInputFormat -import org.apache.hadoop.mapred.InputSplit -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.Reporter - -@serializable class HdfsSplit(@transient s: InputSplit) -extends Split { - val inputSplit = new SerializableWritable[InputSplit](s) - - override def getId() = inputSplit.toString // Hadoop makes this unique - // for each split of each file -} - -class HdfsTextFile(sc: SparkContext, path: String) -extends RDD[String](sc) { - @transient val conf = new JobConf() - @transient val inputFormat = new TextInputFormat() - - FileInputFormat.setInputPaths(conf, path) - ConfigureLock.synchronized { inputFormat.configure(conf) } - - @transient val splits_ = - inputFormat.getSplits(conf, sc.scheduler.numCores).map(new HdfsSplit(_)).toArray - - override def splits = splits_.asInstanceOf[Array[Split]] - - override def iterator(split_in: Split) = new Iterator[String] { - val split = split_in.asInstanceOf[HdfsSplit] - var reader: RecordReader[LongWritable, Text] = null - ConfigureLock.synchronized { - val conf = new JobConf() - conf.set("io.file.buffer.size", - System.getProperty("spark.buffer.size", "65536")) - val tif = new TextInputFormat() - tif.configure(conf) - reader = tif.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) - } - val lineNum = new LongWritable() - val text = new Text() - var gotNext = false - var finished = false - - override def hasNext: Boolean = { - if (!gotNext) { - try { - finished = !reader.next(lineNum, text) - } catch { - case eofe: java.io.EOFException => - finished = true - } - gotNext = true - } - !finished - } - - override def next: String = { - if (!gotNext) - finished = !reader.next(lineNum, text) - if (finished) - throw new java.util.NoSuchElementException("end of stream") - gotNext = false - text.toString - } - } - - override def preferredLocations(split: Split) = { - // TODO: Filtering out "localhost" in case of file:// URLs - split.asInstanceOf[HdfsSplit].inputSplit.value.getLocations().filter(_ != "localhost") - } -} - -object ConfigureLock {} diff --git a/src/scala/spark/repl/ClassServer.scala b/src/scala/spark/HttpServer.scala index 6a40d92765..d5bdd245bb 100644 --- a/src/scala/spark/repl/ClassServer.scala +++ b/src/scala/spark/HttpServer.scala @@ -1,4 +1,4 @@ -package spark.repl +package spark import java.io.File import java.net.InetAddress @@ -7,23 +7,22 @@ import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler - -import spark.Logging +import org.eclipse.jetty.util.thread.QueuedThreadPool /** - * Exception type thrown by ClassServer when it is in the wrong state + * Exception type thrown by HttpServer when it is in the wrong state * for an operation. */ class ServerStateException(message: String) extends Exception(message) /** - * An HTTP server used by the interpreter to allow worker nodes to access - * class files created as the user types in lines of code. This is just a - * wrapper around a Jetty embedded HTTP server. + * An HTTP server for static content used to allow worker nodes to access JARs + * added to SparkContext as well as classes created by the interpreter when + * the user types in code. This is just a wrapper around a Jetty server. */ -class ClassServer(classDir: File) extends Logging { +class HttpServer(resourceBase: File) extends Logging { private var server: Server = null private var port: Int = -1 @@ -32,14 +31,16 @@ class ClassServer(classDir: File) extends Logging { throw new ServerStateException("Server is already started") } else { server = new Server(0) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) val resHandler = new ResourceHandler - resHandler.setResourceBase(classDir.getAbsolutePath) + resHandler.setResourceBase(resourceBase.getAbsolutePath) val handlerList = new HandlerList handlerList.setHandlers(Array(resHandler, new DefaultHandler)) server.setHandler(handlerList) server.start() port = server.getConnectors()(0).getLocalPort() - logDebug("ClassServer started at " + uri) } } diff --git a/src/scala/spark/Job.scala b/src/scala/spark/Job.scala new file mode 100644 index 0000000000..6abbcbce51 --- /dev/null +++ b/src/scala/spark/Job.scala @@ -0,0 +1,18 @@ +package spark + +import mesos._ + +/** + * Class representing a parallel job in MesosScheduler. Schedules the + * job by implementing various callbacks. + */ +abstract class Job(jobId: Int) { + def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int) + : Option[TaskDescription] + + def statusUpdate(t: TaskStatus): Unit + + def error(code: Int, message: String): Unit + + def getId(): Int = jobId +} diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala index 873a97c59c..c45eff64d4 100644 --- a/src/scala/spark/MesosScheduler.scala +++ b/src/scala/spark/MesosScheduler.scala @@ -1,103 +1,130 @@ package spark -import java.io.File +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.{ArrayList => JArrayList} +import java.util.{List => JList} +import java.util.{HashMap => JHashMap} +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet import scala.collection.mutable.Map import scala.collection.mutable.Queue -import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ -import mesos.{Scheduler => NScheduler} +import mesos.{Scheduler => MScheduler} import mesos._ -// The main Scheduler implementation, which talks to Mesos. Clients are expected -// to first call start(), then submit tasks through the runTasks method. -// -// This implementation is currently a little quick and dirty. The following -// improvements need to be made to it: -// 1) Right now, the scheduler uses a linear scan through the tasks to find a -// local one for a given node. It would be faster to have a separate list of -// pending tasks for each node. -// 2) Presenting a single slave in ParallelOperation.slaveOffer makes it -// difficult to balance tasks across nodes. It would be better to pass -// all the offers to the ParallelOperation and have it load-balance. +/** + * The main Scheduler implementation, which runs jobs on Mesos. Clients should + * first call start(), then submit tasks through the runTasks method. + */ private class MesosScheduler( - master: String, frameworkName: String, execArg: Array[Byte]) -extends NScheduler with spark.Scheduler with Logging + sc: SparkContext, master: String, frameworkName: String) +extends MScheduler with spark.Scheduler with Logging { - // Lock used by runTasks to ensure only one thread can be in it - val runTasksMutex = new Object() + // Environment variables to pass to our executors + val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( + "SPARK_MEM", + "SPARK_CLASSPATH", + "SPARK_LIBRARY_PATH" + ) // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() + private var isRegistered = false + private val registeredLock = new Object() - // Current callback object (may be null) - var activeOpsQueue = new Queue[Int] - var activeOps = new HashMap[Int, ParallelOperation] - private var nextOpId = 0 - private[spark] var taskIdToOpId = new HashMap[Int, Int] - - def newOpId(): Int = { - val id = nextOpId - nextOpId += 1 - return id - } + private var activeJobs = new HashMap[Int, Job] + private var activeJobsQueue = new Queue[Job] + + private var taskIdToJobId = new HashMap[Int, Int] + private var jobTasks = new HashMap[Int, HashSet[Int]] - // Incrementing task ID + // Incrementing job and task IDs + private var nextJobId = 0 private var nextTaskId = 0 + // Driver for talking to Mesos + var driver: SchedulerDriver = null + + // JAR server, if any JARs were added by the user to the SparkContext + var jarServer: HttpServer = null + + // URIs of JARs to pass to executor + var jarUris: String = "" + + def newJobId(): Int = this.synchronized { + val id = nextJobId + nextJobId += 1 + return id + } + def newTaskId(): Int = { val id = nextTaskId; nextTaskId += 1; return id } - - // Driver for talking to Mesos - var driver: SchedulerDriver = null override def start() { + if (sc.jars.size > 0) { + // If the user added any JARS to the SparkContext, create an HTTP server + // to serve them to our executors + createJarServer() + } new Thread("Spark scheduler") { setDaemon(true) override def run { - val ns = MesosScheduler.this - ns.driver = new MesosSchedulerDriver(ns, master) - ns.driver.run() + val sched = MesosScheduler.this + sched.driver = new MesosSchedulerDriver(sched, master) + sched.driver.run() } }.start } override def getFrameworkName(d: SchedulerDriver): String = frameworkName - override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = - new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg) + override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = { + val sparkHome = sc.getSparkHome match { + case Some(path) => path + case None => + throw new SparkException("Spark home is not set; set it through the " + + "spark.home system property, the SPARK_HOME environment variable " + + "or the SparkContext constructor") + } + val execScript = new File(sparkHome, "spark-executor").getCanonicalPath + val params = new JHashMap[String, String] + for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) { + if (System.getenv(key) != null) { + params("env." + key) = System.getenv(key) + } + } + new ExecutorInfo(execScript, createExecArg()) + } + /** + * The primary means to submit a job to the scheduler. Given a list of tasks, + * runs them and returns an array of the results. + */ override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = { - var opId = 0 waitForRegister() - this.synchronized { - opId = newOpId() - } - val myOp = new SimpleParallelOperation(this, tasks, opId) - + val jobId = newJobId() + val myJob = new SimpleJob(this, tasks, jobId) try { this.synchronized { - this.activeOps(myOp.opId) = myOp - this.activeOpsQueue += myOp.opId + activeJobs(jobId) = myJob + activeJobsQueue += myJob + jobTasks(jobId) = new HashSet() } driver.reviveOffers(); - myOp.join(); + return myJob.join(); } finally { this.synchronized { - this.activeOps.remove(myOp.opId) - this.activeOpsQueue.dequeueAll(x => (x == myOp.opId)) + activeJobs -= jobId + activeJobsQueue.dequeueAll(x => (x == myJob)) + taskIdToJobId --= jobTasks(jobId) + jobTasks.remove(jobId) } } - - if (myOp.errorHappened) - throw new SparkException(myOp.errorMessage, myOp.errorCode) - else - return myOp.results } override def registered(d: SchedulerDriver, frameworkId: String) { @@ -115,51 +142,68 @@ extends NScheduler with spark.Scheduler with Logging } } + /** + * Method called by Mesos to offer resources on slaves. We resond by asking + * our active jobs for tasks in FIFO order. We fill each node with tasks in + * a round-robin manner so that tasks are balanced across the cluster. + */ override def resourceOffer( - d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) { + d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) { synchronized { - val tasks = new java.util.ArrayList[TaskDescription] + val tasks = new JArrayList[TaskDescription] val availableCpus = offers.map(_.getParams.get("cpus").toInt) val availableMem = offers.map(_.getParams.get("mem").toInt) - var launchedTask = true - for (opId <- activeOpsQueue) { - launchedTask = true - while (launchedTask) { + var launchedTask = false + for (job <- activeJobsQueue) { + do { launchedTask = false for (i <- 0 until offers.size.toInt) { try { - activeOps(opId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match { + job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match { case Some(task) => tasks.add(task) + taskIdToJobId(task.getTaskId) = job.getId + jobTasks(job.getId) += task.getTaskId availableCpus(i) -= task.getParams.get("cpus").toInt availableMem(i) -= task.getParams.get("mem").toInt - launchedTask = launchedTask || true + launchedTask = true case None => {} } } catch { case e: Exception => logError("Exception in resourceOffer", e) } } - } + } while (launchedTask) } - val params = new java.util.HashMap[String, String] + val params = new JHashMap[String, String] params.put("timeout", "1") - d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout + d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout? } } + // Check whether a Mesos task state represents a finished task + def isFinished(state: TaskState) = { + state == TaskState.TASK_FINISHED || + state == TaskState.TASK_FAILED || + state == TaskState.TASK_KILLED || + state == TaskState.TASK_LOST + } + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { synchronized { try { - taskIdToOpId.get(status.getTaskId) match { - case Some(opId) => - if (activeOps.contains(opId)) { - activeOps(opId).statusUpdate(status) + taskIdToJobId.get(status.getTaskId) match { + case Some(jobId) => + if (activeJobs.contains(jobId)) { + activeJobs(jobId).statusUpdate(status) + } + if (isFinished(status.getState)) { + taskIdToJobId.remove(status.getTaskId) + jobTasks(jobId) -= status.getTaskId } case None => logInfo("TID " + status.getTaskId + " already finished") } - } catch { case e: Exception => logError("Exception in statusUpdate", e) } @@ -167,180 +211,84 @@ extends NScheduler with spark.Scheduler with Logging } override def error(d: SchedulerDriver, code: Int, message: String) { + logError("Mesos error: %s (error code: %d)".format(message, code)) synchronized { - if (activeOps.size > 0) { - for ((opId, activeOp) <- activeOps) { + if (activeJobs.size > 0) { + // Have each job throw a SparkException with the error + for ((jobId, activeJob) <- activeJobs) { try { - activeOp.error(code, message) + activeJob.error(code, message) } catch { case e: Exception => logError("Exception in error callback", e) } } } else { - logError("Mesos error: %s (error code: %d)".format(message, code)) + // No jobs are active but we still got an error. Just exit since this + // must mean the error is during registration. + // It might be good to do something smarter here in the future. System.exit(1) } } } override def stop() { - if (driver != null) + if (driver != null) { driver.stop() - } - - // TODO: query Mesos for number of cores - override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt -} - - -// Trait representing an object that manages a parallel operation by -// implementing various scheduler callbacks. -trait ParallelOperation { - def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] - def statusUpdate(t: TaskStatus): Unit - def error(code: Int, message: String): Unit -} - - -class SimpleParallelOperation[T: ClassManifest]( - sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int) -extends ParallelOperation with Logging -{ - // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong - - val callingThread = currentThread - val numTasks = tasks.length - val results = new Array[T](numTasks) - val launched = new Array[Boolean](numTasks) - val finished = new Array[Boolean](numTasks) - val tidToIndex = Map[Int, Int]() - - var allFinished = false - val joinLock = new Object() - - var errorHappened = false - var errorCode = 0 - var errorMessage = "" - - var tasksLaunched = 0 - var tasksFinished = 0 - var lastPreferredLaunchTime = System.currentTimeMillis - - def setAllFinished() { - joinLock.synchronized { - allFinished = true - joinLock.notifyAll() } - } - - def join() { - joinLock.synchronized { - while (!allFinished) - joinLock.wait() - } - } - - def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] = { - if (tasksLaunched < numTasks) { - var checkPrefVals: Array[Boolean] = Array(true) - val time = System.currentTimeMillis - if (time - lastPreferredLaunchTime > LOCALITY_WAIT) - checkPrefVals = Array(true, false) // Allow non-preferred tasks - // TODO: Make desiredCpus and desiredMem configurable - val desiredCpus = 1 - val desiredMem = 500 - if ((availableCpus < desiredCpus) || (availableMem < desiredMem)) - return None - for (checkPref <- checkPrefVals; i <- 0 until numTasks) { - if (!launched(i) && (!checkPref || - tasks(i).preferredLocations.contains(offer.getHost) || - tasks(i).preferredLocations.isEmpty)) - { - val taskId = sched.newTaskId() - sched.taskIdToOpId(taskId) = opId - tidToIndex(taskId) = i - val preferred = if(checkPref) "preferred" else "non-preferred" - val message = - "Starting task %d as opId %d, TID %s on slave %s: %s (%s)".format( - i, opId, taskId, offer.getSlaveId, offer.getHost, preferred) - logInfo(message) - tasks(i).markStarted(offer) - launched(i) = true - tasksLaunched += 1 - if (checkPref) - lastPreferredLaunchTime = time - val params = new java.util.HashMap[String, String] - params.put("cpus", "" + desiredCpus) - params.put("mem", "" + desiredMem) - val serializedTask = Utils.serialize(tasks(i)) - //logInfo("Serialized size: " + serializedTask.size) - return Some(new TaskDescription(taskId, offer.getSlaveId, - "task_" + taskId, params, serializedTask)) - } - } + if (jarServer != null) { + jarServer.stop() } - return None } - def statusUpdate(status: TaskStatus) { - status.getState match { - case TaskState.TASK_FINISHED => - taskFinished(status) - case TaskState.TASK_LOST => - taskLost(status) - case TaskState.TASK_FAILED => - taskLost(status) - case TaskState.TASK_KILLED => - taskLost(status) - case _ => + // TODO: query Mesos for number of cores + override def numCores() = + System.getProperty("spark.default.parallelism", "2").toInt + + // Create a server for all the JARs added by the user to SparkContext. + // We first copy the JARs to a temp directory for easier server setup. + private def createJarServer() { + val jarDir = Utils.createTempDir() + logInfo("Temp directory for JARs: " + jarDir) + val filenames = ArrayBuffer[String]() + // Copy each JAR to a unique filename in the jarDir + for ((path, index) <- sc.jars.zipWithIndex) { + val file = new File(path) + val filename = index + "_" + file.getName + copyFile(file, new File(jarDir, filename)) + filenames += filename } + // Create the server + jarServer = new HttpServer(jarDir) + jarServer.start() + // Build up the jar URI list + val serverUri = jarServer.uri + jarUris = filenames.map(f => serverUri + "/" + f).mkString(",") + logInfo("JAR server started at " + serverUri) } - def taskFinished(status: TaskStatus) { - val tid = status.getTaskId - val index = tidToIndex(tid) - if (!finished(index)) { - tasksFinished += 1 - logInfo("Finished opId %d TID %d (progress: %d/%d)".format( - opId, tid, tasksFinished, numTasks)) - // Deserialize task result - val result = Utils.deserialize[TaskResult[T]](status.getData) - results(index) = result.value - // Update accumulators - Accumulators.add(callingThread, result.accumUpdates) - // Mark finished and stop if we've finished all the tasks - finished(index) = true - // Remove TID -> opId mapping from sched - sched.taskIdToOpId.remove(tid) - if (tasksFinished == numTasks) - setAllFinished() - } else { - logInfo("Ignoring task-finished event for TID " + tid + - " because task " + index + " is already finished") - } + // Copy a file on the local file system + private def copyFile(source: File, dest: File) { + val in = new FileInputStream(source) + val out = new FileOutputStream(dest) + Utils.copyStream(in, out, true) } - def taskLost(status: TaskStatus) { - val tid = status.getTaskId - val index = tidToIndex(tid) - if (!finished(index)) { - logInfo("Lost opId " + opId + " TID " + tid) - launched(index) = false - sched.taskIdToOpId.remove(tid) - tasksLaunched -= 1 - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") + // Create and serialize the executor argument to pass to Mesos. + // Our executor arg is an array containing all the spark.* system properties + // in the form of (String, String) pairs. + private def createExecArg(): Array[Byte] = { + val props = new HashMap[String, String] + val iter = System.getProperties.entrySet.iterator + while (iter.hasNext) { + val entry = iter.next + val (key, value) = (entry.getKey.toString, entry.getValue.toString) + if (key.startsWith("spark.")) { + props(key) = value + } } - } - - def error(code: Int, message: String) { - // Save the error message - errorHappened = true - errorCode = code - errorMessage = message - // Indicate to caller thread that we're done - setAllFinished() + // Set spark.jar.uris to our JAR URIs, regardless of system property + props("spark.jar.uris") = jarUris + // Serialize the map as an array of (String, String) pairs + return Utils.serialize(props.toArray) } } diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala index 803c063865..9dd8bc9dce 100644 --- a/src/scala/spark/RDD.scala +++ b/src/scala/spark/RDD.scala @@ -78,15 +78,14 @@ abstract class RDD[T: ClassManifest]( case _ => throw new UnsupportedOperationException("empty collection") } - def count(): Long = + def count(): Long = try { map(x => 1L).reduce(_+_) } catch { case e: UnsupportedOperationException => 0L } - def union(other: RDD[T]) = new UnionRDD(sc, this, other) + def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other)) def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other) def ++(other: RDD[T]) = this.union(other) - } @serializable @@ -129,7 +128,7 @@ extends RDDTask[Option[T], T](rdd, split) with Logging { } class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], f: T => U) + prev: RDD[T], f: T => U) extends RDD[U](prev.sparkContext) { override def splits = prev.splits override def preferredLocations(split: Split) = prev.preferredLocations(split) @@ -138,7 +137,7 @@ extends RDD[U](prev.sparkContext) { } class FilteredRDD[T: ClassManifest]( - prev: RDD[T], f: T => Boolean) + prev: RDD[T], f: T => Boolean) extends RDD[T](prev.sparkContext) { override def splits = prev.splits override def preferredLocations(split: Split) = prev.preferredLocations(split) @@ -147,7 +146,7 @@ extends RDD[T](prev.sparkContext) { } class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], f: T => Traversable[U]) + prev: RDD[T], f: T => Traversable[U]) extends RDD[U](prev.sparkContext) { override def splits = prev.splits override def preferredLocations(split: Split) = prev.preferredLocations(split) @@ -156,7 +155,7 @@ extends RDD[U](prev.sparkContext) { override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot) } -class SplitRDD[T: ClassManifest](prev: RDD[T]) +class SplitRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.sparkContext) { override def splits = prev.splits override def preferredLocations(split: Split) = prev.preferredLocations(split) @@ -171,16 +170,16 @@ extends RDD[Array[T]](prev.sparkContext) { } class SampledRDD[T: ClassManifest]( - prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int) + prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T](prev.sparkContext) { - + @transient val splits_ = { val rg = new Random(seed); prev.splits.map(x => new SeededSplit(x, rg.nextInt)) } override def splits = splits_.asInstanceOf[Array[Split]] override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SeededSplit].prev) - override def iterator(splitIn: Split) = { + override def iterator(splitIn: Split) = { val split = splitIn.asInstanceOf[SeededSplit] val rg = new Random(split.seed); // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?) @@ -214,7 +213,7 @@ extends RDD[T](prev.sparkContext) with Logging { else prev.preferredLocations(split) } - + override def iterator(split: Split): Iterator[T] = { val key = id + "::" + split.getId() logInfo("CachedRDD split key is " + key) @@ -268,38 +267,29 @@ private object CachedRDD { } @serializable -abstract class UnionSplit[T: ClassManifest] extends Split { - def iterator(): Iterator[T] - def preferredLocations(): Seq[String] - def getId(): String -} - -@serializable -class UnionSplitImpl[T: ClassManifest]( - rdd: RDD[T], split: Split) -extends UnionSplit[T] { - override def iterator() = rdd.iterator(split) - override def preferredLocations() = rdd.preferredLocations(split) - override def getId() = - "UnionSplitImpl(" + split.getId() + ")" +class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split) +extends Split { + def iterator() = rdd.iterator(split) + def preferredLocations() = rdd.preferredLocations(split) + override def getId() = "UnionSplit(" + split.getId() + ")" } @serializable -class UnionRDD[T: ClassManifest]( - sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T]) +class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]]) extends RDD[T](sc) { - - @transient val splits_ : Array[UnionSplit[T]] = { - val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s)) - val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s)) - (a1 ++ a2).toArray + @transient val splits_ : Array[Split] = { + val splits: Seq[Split] = + for (rdd <- rdds; split <- rdd.splits) + yield new UnionSplit(rdd, split) + splits.toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def splits = splits_ - override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() + override def iterator(s: Split): Iterator[T] = + s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = + override def preferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() } diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala new file mode 100644 index 0000000000..b15d0522d4 --- /dev/null +++ b/src/scala/spark/SimpleJob.scala @@ -0,0 +1,255 @@ +package spark + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import mesos._ + + +/** + * A Job that runs a set of tasks with no interdependencies. + */ +class SimpleJob[T: ClassManifest]( + sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int) +extends Job(jobId) with Logging +{ + // Maximum time to wait to run a task in a preferred location (in ms) + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + // CPUs and memory to request per task + val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt + val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt + + // Maximum times a task is allowed to fail before failing the job + val MAX_TASK_FAILURES = 4 + + val callingThread = currentThread + val numTasks = tasks.length + val results = new Array[T](numTasks) + val launched = new Array[Boolean](numTasks) + val finished = new Array[Boolean](numTasks) + val numFailures = new Array[Int](numTasks) + val tidToIndex = HashMap[Int, Int]() + + var allFinished = false + val joinLock = new Object() // Used to wait for all tasks to finish + + var tasksLaunched = 0 + var tasksFinished = 0 + + // Last time when we launched a preferred task (for delay scheduling) + var lastPreferredLaunchTime = System.currentTimeMillis + + // List of pending tasks for each node. These collections are actually + // treated as stacks, in which new tasks are added to the end of the + // ArrayBuffer and removed from the end. This makes it faster to detect + // tasks that repeatedly fail because whenever a task failed, it is put + // back at the head of the stack. They are also only cleaned up lazily; + // when a task is launched, it remains in all the pending lists except + // the one that it was launched from, but gets removed from them later. + val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + + // List containing all pending tasks (also used as a stack, as above) + val allPendingTasks = new ArrayBuffer[Int] + + // Did the job fail? + var failed = false + var causeOfFailure = "" + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. + for (i <- (0 until numTasks).reverse) { + addPendingTask(i) + } + + // Add a task to all the pending-task lists that it should be on. + def addPendingTask(index: Int) { + allPendingTasks += index + for (host <- tasks(index).preferredLocations) { + val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + list += index + } + } + + // Mark the job as finished and wake up any threads waiting on it + def setAllFinished() { + joinLock.synchronized { + allFinished = true + joinLock.notifyAll() + } + } + + // Wait until the job finishes and return its results + def join(): Array[T] = { + joinLock.synchronized { + while (!allFinished) { + joinLock.wait() + } + if (failed) { + throw new SparkException(causeOfFailure) + } else { + return results + } + } + } + + // Return the pending tasks list for a given host, or an empty list if + // there is no map entry for that host + def getPendingTasksForHost(host: String): ArrayBuffer[Int] = { + pendingTasksForHost.getOrElse(host, ArrayBuffer()) + } + + // Dequeue a pending task from the given list and return its index. + // Return None if the list is empty. + def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = { + while (!list.isEmpty) { + val index = list.last + list.trimEnd(1) + if (!launched(index) && !finished(index)) { + return Some(index) + } + } + return None + } + + // Dequeue a pending task for a given node and return its index. + // If localOnly is set to false, allow non-local tasks as well. + def findTask(host: String, localOnly: Boolean): Option[Int] = { + findTaskFromList(getPendingTasksForHost(host)) match { + case Some(task) => Some(task) + case None => + if (localOnly) None + else findTaskFromList(allPendingTasks) + } + } + + // Does a host count as a preferred location for a task? This is true if + // either the task has preferred locations and this host is one, or it has + // no preferred locations (in which we still count the launch as preferred). + def isPreferredLocation(task: Task[T], host: String): Boolean = { + val locs = task.preferredLocations + return (locs.contains(host) || locs.isEmpty) + } + + // Respond to an offer of a single slave from the scheduler by finding a task + def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int) + : Option[TaskDescription] = { + if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK && + availableMem >= MEM_PER_TASK) { + val time = System.currentTimeMillis + val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + val host = offer.getHost + findTask(host, localOnly) match { + case Some(index) => { + // Found a task; do some bookkeeping and return a Mesos task for it + val task = tasks(index) + val taskId = sched.newTaskId() + // Figure out whether this should count as a preferred launch + val preferred = isPreferredLocation(task, host) + val prefStr = if(preferred) "preferred" else "non-preferred" + val message = + "Starting task %d:%d as TID %s on slave %s: %s (%s)".format( + jobId, index, taskId, offer.getSlaveId, host, prefStr) + logInfo(message) + // Do various bookkeeping + tidToIndex(taskId) = index + task.markStarted(offer) + launched(index) = true + tasksLaunched += 1 + if (preferred) + lastPreferredLaunchTime = time + // Create and return the Mesos task object + val params = new JHashMap[String, String] + params.put("cpus", CPUS_PER_TASK.toString) + params.put("mem", MEM_PER_TASK.toString) + val serializedTask = Utils.serialize(task) + logDebug("Serialized size: " + serializedTask.size) + val taskName = "task %d:%d".format(jobId, index) + return Some(new TaskDescription( + taskId, offer.getSlaveId, taskName, params, serializedTask)) + } + case _ => + } + } + return None + } + + def statusUpdate(status: TaskStatus) { + status.getState match { + case TaskState.TASK_FINISHED => + taskFinished(status) + case TaskState.TASK_LOST => + taskLost(status) + case TaskState.TASK_FAILED => + taskLost(status) + case TaskState.TASK_KILLED => + taskLost(status) + case _ => + } + } + + def taskFinished(status: TaskStatus) { + val tid = status.getTaskId + val index = tidToIndex(tid) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %d (progress: %d/%d)".format( + tid, tasksFinished, numTasks)) + // Deserialize task result + val result = Utils.deserialize[TaskResult[T]](status.getData) + results(index) = result.value + // Update accumulators + Accumulators.add(callingThread, result.accumUpdates) + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) + setAllFinished() + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def taskLost(status: TaskStatus) { + val tid = status.getTaskId + val index = tidToIndex(tid) + if (!finished(index)) { + logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index)) + launched(index) = false + tasksLaunched -= 1 + // Re-enqueue the task as pending + addPendingTask(index) + // Mark it as failed + if (status.getState == TaskState.TASK_FAILED || + status.getState == TaskState.TASK_LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %d:%d failed more than %d times; aborting job".format( + jobId, index, MAX_TASK_FAILURES)) + abort("Task %d failed more than %d times".format( + index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } + + def error(code: Int, message: String) { + // Save the error message + abort("Mesos error: %s (error code: %d)".format(message, code)) + } + + def abort(message: String) { + joinLock.synchronized { + failed = true + causeOfFailure = message + // TODO: Kill running tasks if we were not terminated due to a Mesos error + // Indicate to any joining thread that we're done + setAllFinished() + } + } +} diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala index 90bea8921a..e85b26e238 100644 --- a/src/scala/spark/SparkContext.scala +++ b/src/scala/spark/SparkContext.scala @@ -1,59 +1,136 @@ package spark import java.io._ -import java.util.UUID import scala.collection.mutable.ArrayBuffer -import scala.actors.Actor._ -class SparkContext(master: String, frameworkName: String) extends Logging { +import org.apache.hadoop.mapred.InputFormat +import org.apache.hadoop.mapred.SequenceFileInputFormat + + +class SparkContext( + master: String, + frameworkName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil) +extends Logging { + private[spark] var scheduler: Scheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + master match { + case "local" => + new LocalScheduler(1) + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt) + case _ => + System.loadLibrary("mesos") + new MesosScheduler(this, master, frameworkName) + } + } + + private val isLocal = scheduler.isInstanceOf[LocalScheduler] + + // Start the scheduler and the broadcast system + scheduler.start() Broadcast.initialize(true) - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) = + // Methods for creating RDDs + + def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] = new ParallelArray[T](this, seq, numSlices) - def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] = + def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] = parallelize(seq, scheduler.numCores) - def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = - new Accumulator(initialValue, param) + def textFile(path: String): RDD[String] = + new HadoopTextFile(this, path) - // TODO: Keep around a weak hash map of values to Cached versions? - def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local) - // def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local) + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ + def hadoopFile[K, V](path: String, + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V]) + : RDD[(K, V)] = { + new HadoopFile(this, path, inputFormatClass, keyClass, valueClass) + } - def textFile(path: String) = new HdfsTextFile(this, path) + /** + * Smarter version of hadoopFile() that uses class manifests to figure out + * the classes of keys, values and the InputFormat so that users don't need + * to pass them directly. + */ + def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) + (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) + : RDD[(K, V)] = { + hadoopFile(path, + fm.erasure.asInstanceOf[Class[F]], + km.erasure.asInstanceOf[Class[K]], + vm.erasure.asInstanceOf[Class[V]]) + } - val LOCAL_REGEX = """local\[([0-9]+)\]""".r + /** Get an RDD for a Hadoop SequenceFile with given key and value types */ + def sequenceFile[K, V](path: String, + keyClass: Class[K], + valueClass: Class[V]): RDD[(K, V)] = { + val inputFormatClass = classOf[SequenceFileInputFormat[K, V]] + hadoopFile(path, inputFormatClass, keyClass, valueClass) + } - private[spark] var scheduler: Scheduler = master match { - case "local" => new LocalScheduler(1) - case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt) - case _ => { System.loadLibrary("mesos"); - new MesosScheduler(master, frameworkName, createExecArg()) } + /** + * Smarter version of sequenceFile() that obtains the key and value classes + * from ClassManifests instead of requiring the user to pass them directly. + */ + def sequenceFile[K, V](path: String) + (implicit km: ClassManifest[K], vm: ClassManifest[V]): RDD[(K, V)] = { + sequenceFile(path, + km.erasure.asInstanceOf[Class[K]], + vm.erasure.asInstanceOf[Class[V]]) } - private val local = scheduler.isInstanceOf[LocalScheduler] + /** Build the union of a list of RDDs. */ + def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] = + new UnionRDD(this, rdds) - scheduler.start() + // Methods for creating shared variables - private def createExecArg(): Array[Byte] = { - // Our executor arg is an array containing all the spark.* system properties - val props = new ArrayBuffer[(String, String)] - val iter = System.getProperties.entrySet.iterator - while (iter.hasNext) { - val entry = iter.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) - if (key.startsWith("spark.")) - props += key -> value - } - return Utils.serialize(props.toArray) + def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = + new Accumulator(initialValue, param) + + // TODO: Keep around a weak hash map of values to Cached versions? + def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, isLocal) + //def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, isLocal) + + // Stop the SparkContext + def stop() { + scheduler.stop() + scheduler = null } + // Wait for the scheduler to be registered + def waitForRegister() { + scheduler.waitForRegister() + } + + // Get Spark's home location from either a value set through the constructor, + // or the spark.home Java property, or the SPARK_HOME environment variable + // (in that order of preference). If neither of these is set, return None. + def getSparkHome(): Option[String] = { + if (sparkHome != null) + Some(sparkHome) + else if (System.getProperty("spark.home") != null) + Some(System.getProperty("spark.home")) + else if (System.getenv("SPARK_HOME") != null) + Some(System.getenv("SPARK_HOME")) + else + None + } + + // Submit an array of tasks (passed as functions) to the scheduler def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = { runTaskObjects(tasks.map(f => new FunctionTask(f))) } + // Run an array of spark.Task objects private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]]) : Array[T] = { logInfo("Running " + tasks.length + " tasks in parallel") @@ -63,15 +140,6 @@ class SparkContext(master: String, frameworkName: String) extends Logging { return result } - def stop() { - scheduler.stop() - scheduler = null - } - - def waitForRegister() { - scheduler.waitForRegister() - } - // Clean a closure to make it ready to serialized and send to tasks // (removes unreferenced variables in $outer's, updates REPL variables) private[spark] def clean[F <: AnyRef](f: F): F = { @@ -80,6 +148,11 @@ class SparkContext(master: String, frameworkName: String) extends Logging { } } + +/** + * The SparkContext object contains a number of implicit conversions and + * parameters for use with various Spark features. + */ object SparkContext { implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala index 7257bf7b0c..6f9be1a94f 100644 --- a/src/scala/spark/SparkException.scala +++ b/src/scala/spark/SparkException.scala @@ -1,7 +1,3 @@ package spark -class SparkException(message: String) extends Exception(message) { - def this(message: String, errorCode: Int) { - this("%s (error code: %d)".format(message, errorCode)) - } -} +class SparkException(message: String) extends Exception(message) {} diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala index 27d73aefbd..9d300d229a 100644 --- a/src/scala/spark/Utils.scala +++ b/src/scala/spark/Utils.scala @@ -1,9 +1,13 @@ package spark import java.io._ +import java.util.UUID import scala.collection.mutable.ArrayBuffer +/** + * Various utility methods used by Spark. + */ object Utils { def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream @@ -50,4 +54,45 @@ object Utils { } return buf } + + // Create a temporary directory inside the given parent directory + def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = + { + var attempts = 0 + val maxAttempts = 10 + var dir: File = null + while (dir == null) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException("Failed to create a temp directory " + + "after " + maxAttempts + " attempts!") + } + try { + dir = new File(root, "spark-" + UUID.randomUUID.toString) + if (dir.exists() || !dir.mkdirs()) { + dir = null + } + } catch { case e: IOException => ; } + } + return dir + } + + // Copy all data from an InputStream to an OutputStream + def copyStream(in: InputStream, + out: OutputStream, + closeStreams: Boolean = false) + { + val buf = new Array[Byte](8192) + var n = 0 + while (n != -1) { + n = in.read(buf) + if (n != -1) { + out.write(buf, 0, n) + } + } + if (closeStreams) { + in.close() + out.close() + } + } } diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala index ae2e7e8a68..10ea346658 100644 --- a/src/scala/spark/repl/SparkInterpreter.scala +++ b/src/scala/spark/repl/SparkInterpreter.scala @@ -36,6 +36,9 @@ import scala.tools.nsc.{ InterpreterResults => IR } import interpreter._ import SparkInterpreter._ +import spark.HttpServer +import spark.Utils + /** <p> * An interpreter for Scala code. * </p> @@ -92,27 +95,12 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { /** Local directory to save .class files too */ val outputDir = { - val rootDir = new File(System.getProperty("spark.repl.classdir", - System.getProperty("java.io.tmpdir"))) - var attempts = 0 - val maxAttempts = 10 - var dir: File = null - while (dir == null) { - attempts += 1 - if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory " + - "after " + maxAttempts + " attempts!") - } - try { - dir = new File(rootDir, "spark-" + UUID.randomUUID.toString) - if (dir.exists() || !dir.mkdirs()) - dir = null - } catch { case e: IOException => ; } - } - if (SPARK_DEBUG_REPL) { - println("Output directory: " + dir) - } - dir + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = System.getProperty("spark.repl.classdir", tmp) + Utils.createTempDir(rootDir) + } + if (SPARK_DEBUG_REPL) { + println("Output directory: " + outputDir) } /** Scala compiler virtual directory for outputDir */ @@ -120,14 +108,14 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) { val virtualDirectory = new PlainFile(outputDir) /** Jetty server that will serve our classes to worker nodes */ - val classServer = new ClassServer(outputDir) + val classServer = new HttpServer(outputDir) // Start the classServer and store its URI in a spark system property // (which will be passed to executors so that they can connect to it) classServer.start() System.setProperty("spark.repl.class.uri", classServer.uri) if (SPARK_DEBUG_REPL) { - println("ClassServer started, URI = " + classServer.uri) + println("Class server started, URI = " + classServer.uri) } /** reporter */ diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala index dcf71182ec..9f41c77ead 100644 --- a/src/test/spark/repl/ReplSuite.scala +++ b/src/test/spark/repl/ReplSuite.scala @@ -39,9 +39,9 @@ class ReplSuite extends FunSuite { test ("external vars") { val output = runInterpreter("local", """ var v = 7 - sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) v = 10 - sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -54,7 +54,7 @@ class ReplSuite extends FunSuite { class C { def foo = 5 } - sc.parallelize(1 to 10).map(x => (new C).foo).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -64,7 +64,7 @@ class ReplSuite extends FunSuite { test ("external functions") { val output = runInterpreter("local", """ def double(x: Int) = x + x - sc.parallelize(1 to 10).map(x => double(x)).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -75,9 +75,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ var v = 7 def getV() = v - sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) v = 10 - sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -92,9 +92,9 @@ class ReplSuite extends FunSuite { val output = runInterpreter("local", """ var array = new Array[Int](5) val broadcastArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect """) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) @@ -103,23 +103,27 @@ class ReplSuite extends FunSuite { } test ("running on Mesos") { - val output = runInterpreter("localquiet", """ - var v = 7 - def getV() = v - sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) - v = 10 - sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_) - var array = new Array[Int](5) - val broadcastArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray - array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 70", output) - assertContains("res1: Int = 100", output) - assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) - assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + if (System.getenv("MESOS_HOME") != null) { + val output = runInterpreter("localquiet", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + var array = new Array[Int](5) + val broadcastArray = sc.broadcast(array) + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + array(0) = 5 + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } else { + info("Skipping \"running on Mesos\" test because MESOS_HOME is not set"); + } } } diff --git a/third_party/guava-r06/guava-r06.jar b/third_party/guava-r06/guava-r06.jar Binary files differdeleted file mode 100644 index 8ff3a81748..0000000000 --- a/third_party/guava-r06/guava-r06.jar +++ /dev/null diff --git a/third_party/guava-r06/COPYING b/third_party/guava-r07/COPYING index d645695673..d645695673 100644 --- a/third_party/guava-r06/COPYING +++ b/third_party/guava-r07/COPYING diff --git a/third_party/guava-r06/README b/third_party/guava-r07/README index a0e832dd54..a0e832dd54 100644 --- a/third_party/guava-r06/README +++ b/third_party/guava-r07/README diff --git a/third_party/guava-r07/guava-r07.jar b/third_party/guava-r07/guava-r07.jar Binary files differnew file mode 100644 index 0000000000..a6c9ce02df --- /dev/null +++ b/third_party/guava-r07/guava-r07.jar |