author | Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)> | 2010-11-27 01:27:20 -0800 |
---|---|---|
committer | Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)> | 2010-11-27 01:27:20 -0800 |
commit | 19a6b194a9aceb3d7997f0acfc99530ac2792be4 (patch) | |
tree | 7565c6f7c334f13b583d22d7fdfb569e3dd12671 | |
parent | e4b8db45aef934929dbab443156375aebb1ea45e (diff) | |
parent | f8ea98d9894d72feb7e8cd3951a576b24b448397 (diff) | |
Merge branch 'master' into multi-tracker
Conflicts:
Makefile
run
src/scala/spark/Broadcast.scala
src/scala/spark/Executor.scala
src/scala/spark/HdfsFile.scala
src/scala/spark/MesosScheduler.scala
src/scala/spark/RDD.scala
src/scala/spark/SparkContext.scala
src/scala/spark/Split.scala
src/scala/spark/Utils.scala
src/scala/spark/repl/SparkInterpreter.scala
third_party/mesos.jar
38 files changed, 1957 insertions, 520 deletions
diff --git a/.gitignore b/.gitignore
index 2d12458a44..5abdec5d50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,8 @@
 build
 work
 .DS_Store
+third_party/libmesos.so
+third_party/libmesos.dylib
+conf/java-opts
+conf/spark-env.sh
+conf/log4j.properties
diff --git a/Makefile b/Makefile
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,7 @@ SPACE = $(EMPTY) $(EMPTY)
 JARS = third_party/mesos.jar
 JARS += third_party/asm-3.2/lib/all/asm-all-3.2.jar
 JARS += third_party/colt.jar
-JARS += third_party/guava-r06/guava-r06.jar
+JARS += third_party/guava-r07/guava-r07.jar
 JARS += third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar
 JARS += third_party/hadoop-0.20.0/lib/commons-logging-1.0.4.jar
 JARS += third_party/scalatest-1.2/scalatest-1.2.jar
@@ -34,13 +34,15 @@ else
 COMPILER = $(SCALA_HOME)/bin/$(COMPILER_NAME)
 endif

-all: scala java
+CONF_FILES = conf/spark-env.sh conf/log4j.properties conf/java-opts
+
+all: scala java conf-files

 build/classes:
 	mkdir -p build/classes

 scala: build/classes java
-	$(COMPILER) -unchecked -d build/classes -classpath build/classes:$(CLASSPATH) $(SCALA_SOURCES)
+	$(COMPILER) -d build/classes -classpath build/classes:$(CLASSPATH) $(SCALA_SOURCES)

 java: $(JAVA_SOURCES) build/classes
 	javac -d build/classes $(JAVA_SOURCES)
@@ -50,6 +52,8 @@ native: java

 jar: build/spark.jar build/spark-dep.jar

+dep-jar: build/spark-dep.jar
+
 build/spark.jar: scala java
 	jar cf build/spark.jar -C build/classes spark

@@ -58,6 +62,11 @@ build/spark-dep.jar:
 	cd build/dep && for i in $(JARS); do jar xf ../../$$i; done
 	jar cf build/spark-dep.jar -C build/dep .

+conf-files: $(CONF_FILES)
+
+$(CONF_FILES): %: | %.template
+	cp $@.template $@
+
 test: all
 	./alltests

@@ -67,4 +76,4 @@ clean:
 	$(MAKE) -C src/native clean
 	rm -rf build

-.phony: default all clean scala java native jar
+.phony: default all clean scala java native jar dep-jar conf-files
diff --git a/alltests b/alltests
--- a/alltests
+++ b/alltests
@@ -1,3 +1,11 @@
 #!/bin/bash
-FWDIR=`dirname $0`
-$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -o $@
+FWDIR="`dirname $0`"
+if [ "x$SPARK_MEM" == "x" ]; then
+  export SPARK_MEM=500m
+fi
+RESULTS_DIR="$FWDIR/build/test_results"
+if [ -d $RESULTS_DIR ]; then
+  rm -r $RESULTS_DIR
+fi
+mkdir -p $RESULTS_DIR
+$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -u $RESULTS_DIR -o $@
diff --git a/conf/java-opts.template b/conf/java-opts.template
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/conf/java-opts.template
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
new file mode 100644
index 0000000000..d72dbadc39
--- /dev/null
+++ b/conf/log4j.properties.template
@@ -0,0 +1,8 @@
+# Set everything to be logged to the console
+log4j.rootCategory=INFO, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
new file mode 100755
index 0000000000..6852b23a34
--- /dev/null
+++ b/conf/spark-env.sh.template
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+# Set Spark environment variables for your site in this file. Some useful
+# variables to set are:
+# - MESOS_HOME, to point to your Mesos installation
+# - SCALA_HOME, to point to your Scala installation
+# - SPARK_CLASSPATH, to add elements to Spark's classpath
+# - SPARK_JAVA_OPTS, to add JVM options
+# - SPARK_MEM, to change the amount of memory used per node (this should
+#   be in the same format as the JVM's -Xmx option, e.g. 300m or 1g).
+# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries.
+
+
diff --git a/run b/run
--- a/run
+++ b/run
@@ -1,16 +1,22 @@
 #!/bin/bash

 # Figure out where the Scala framework is installed
-FWDIR=`dirname $0`
+FWDIR="$(cd `dirname $0`; pwd)"
+
+# Export this as SPARK_HOME
+export SPARK_HOME="$FWDIR"

 # Load environment variables from conf/spark-env.sh, if it exists
 if [ -e $FWDIR/conf/spark-env.sh ] ; then
   . $FWDIR/conf/spark-env.sh
 fi

+MESOS_CLASSPATH=""
+MESOS_LIBRARY_PATH=""
+
 if [ "x$MESOS_HOME" != "x" ] ; then
-  SPARK_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar:$SPARK_CLASSPATH"
-  SPARK_LIBRARY_PATH="$MESOS_HOME/lib/java:$SPARK_LIBARY_PATH"
+  MESOS_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar"
+  MESOS_LIBRARY_PATH="$MESOS_HOME/lib/java"
 fi

 if [ "x$SPARK_MEM" == "x" ] ; then
@@ -19,7 +25,7 @@ fi

 # Set JAVA_OPTS to be able to load native libraries and to set heap size
 JAVA_OPTS="$SPARK_JAVA_OPTS"
-JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native"
+JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native:$MESOS_LIBRARY_PATH"
 JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM"
 # Load extra JAVA_OPTS from conf/java-opts, if it exists
 if [ -e $FWDIR/conf/java-opts ] ; then
@@ -28,12 +34,12 @@ fi
 export JAVA_OPTS

 # Build up classpath
-CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes"
+CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes:$MESOS_CLASSPATH"
 CLASSPATH+=:$FWDIR/conf
 CLASSPATH+=:$FWDIR/third_party/mesos.jar
 CLASSPATH+=:$FWDIR/third_party/asm-3.2/lib/all/asm-all-3.2.jar
 CLASSPATH+=:$FWDIR/third_party/colt.jar
-CLASSPATH+=:$FWDIR/third_party/guava-r06/guava-r06.jar
+CLASSPATH+=:$FWDIR/third_party/guava-r07/guava-r07.jar
 CLASSPATH+=:$FWDIR/third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar
 CLASSPATH+=:$FWDIR/third_party/scalatest-1.2/scalatest-1.2.jar
 CLASSPATH+=:$FWDIR/third_party/scalacheck_2.8.0-1.7.jar
diff --git a/src/examples/BroadcastTest.scala b/src/examples/BroadcastTest.scala
index 7764013413..40c2be8f6d 100644
--- a/src/examples/BroadcastTest.scala
+++ b/src/examples/BroadcastTest.scala
@@ -10,15 +10,19 @@ object BroadcastTest {
     val slices = if (args.length > 1) args(1).toInt else 2
     val num = if (args.length > 2) args(2).toInt else 1000000

-    var arr = new Array[Int](num)
-    for (i <- 0 until arr.length)
-      arr(i) = i
+    var arr1 = new Array[Int](num)
+    for (i <- 0 until arr1.length)
+      arr1(i) = i

-    val barr = spark.broadcast(arr)
+//    var arr2 = new Array[Int](num * 2)
+//    for (i <- 0 until arr2.length)
+//      arr2(i) = i
+
+    val barr1 = spark.broadcast(arr1)
+//    val barr2 = spark.broadcast(arr2)
     spark.parallelize(1 to 10, slices).foreach {
-      println("in task: barr = " + barr)
-      i => println(barr.value.size)
+//      i => println(barr1.value.size + barr2.value.size)
+      i => println(barr1.value.size)
     }
   }
 }
-
diff --git a/src/examples/SparkPi.scala b/src/examples/SparkPi.scala
index 07311908ee..f055614125 100644
--- a/src/examples/SparkPi.scala
+++ b/src/examples/SparkPi.scala
@@ -5,7 +5,7 @@ import SparkContext._
 object SparkPi {
   def main(args: Array[String]) {
     if (args.length == 0) {
-      System.err.println("Usage: SparkLR <host> [<slices>]")
+      System.err.println("Usage: SparkPi <host> [<slices>]")
[<slices>]") System.exit(1) } val spark = new SparkContext(args(0), "SparkPi") diff --git a/src/scala/spark/BoundedMemoryCache.scala b/src/scala/spark/BoundedMemoryCache.scala new file mode 100644 index 0000000000..19d9bebfe5 --- /dev/null +++ b/src/scala/spark/BoundedMemoryCache.scala @@ -0,0 +1,69 @@ +package spark + +import java.util.LinkedHashMap + +/** + * An implementation of Cache that estimates the sizes of its entries and + * attempts to limit its total memory usage to a fraction of the JVM heap. + * Objects' sizes are estimated using SizeEstimator, which has limitations; + * most notably, we will overestimate total memory used if some cache + * entries have pointers to a shared object. Nonetheless, this Cache should + * work well when most of the space is used by arrays of primitives or of + * simple classes. + */ +class BoundedMemoryCache extends Cache with Logging { + private val maxBytes: Long = getMaxBytes() + logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) + + private var currentBytes = 0L + private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true) + + // An entry in our map; stores a cached object and its size in bytes + class Entry(val value: Any, val size: Long) {} + + override def get(key: Any): Any = { + synchronized { + val entry = map.get(key) + if (entry != null) entry.value else null + } + } + + override def put(key: Any, value: Any) { + logInfo("Asked to add key " + key) + val startTime = System.currentTimeMillis + val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Estimated size for key %s is %d".format(key, size)) + logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) + synchronized { + ensureFreeSpace(size) + logInfo("Adding key " + key) + map.put(key, new Entry(value, size)) + currentBytes += size + logInfo("Number of entries is now " + map.size) + } + } + + private def getMaxBytes(): Long = { + val memoryFractionToUse = System.getProperty( + "spark.boundedMemoryCache.memoryFraction", "0.75").toDouble + (Runtime.getRuntime.totalMemory * memoryFractionToUse).toLong + } + + /** + * Remove least recently used entries from the map until at least space + * bytes are free. Assumes that a lock is held on the BoundedMemoryCache. + */ + private def ensureFreeSpace(space: Long) { + logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format( + space, currentBytes, maxBytes)) + val iter = map.entrySet.iterator + while (maxBytes - currentBytes < space && iter.hasNext) { + val mapEntry = iter.next() + logInfo("Dropping key %s of size %d to make space".format( + mapEntry.getKey, mapEntry.getValue.size)) + currentBytes -= mapEntry.getValue.size + iter.remove() + } + } +} diff --git a/src/scala/spark/Cache.scala b/src/scala/spark/Cache.scala new file mode 100644 index 0000000000..9887520758 --- /dev/null +++ b/src/scala/spark/Cache.scala @@ -0,0 +1,63 @@ +package spark + +import java.util.concurrent.atomic.AtomicLong + + +/** + * An interface for caches in Spark, to allow for multiple implementations. + * Caches are used to store both partitions of cached RDDs and broadcast + * variables on Spark executors. + * + * A single Cache instance gets created on each machine and is shared by all + * caches (i.e. both the RDD split cache and the broadcast variable cache), + * to enable global replacement policies. 
diff --git a/src/scala/spark/Cache.scala b/src/scala/spark/Cache.scala
new file mode 100644
index 0000000000..9887520758
--- /dev/null
+++ b/src/scala/spark/Cache.scala
@@ -0,0 +1,63 @@
+package spark
+
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+ * An interface for caches in Spark, to allow for multiple implementations.
+ * Caches are used to store both partitions of cached RDDs and broadcast
+ * variables on Spark executors.
+ *
+ * A single Cache instance gets created on each machine and is shared by all
+ * caches (i.e. both the RDD split cache and the broadcast variable cache),
+ * to enable global replacement policies. However, because these several
+ * independent modules all perform caching, it is important to give them
+ * separate key namespaces, so that an RDD and a broadcast variable (for
+ * example) do not use the same key. For this purpose, Cache has the
+ * notion of KeySpaces. Each client module must first ask for a KeySpace,
+ * and then call get() and put() on that space using its own keys.
+ * This abstract class handles the creation of key spaces, so that subclasses
+ * need only deal with keys that are unique across modules.
+ */
+abstract class Cache {
+  private val nextKeySpaceId = new AtomicLong(0)
+  private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
+
+  def newKeySpace() = new KeySpace(this, newKeySpaceId())
+
+  def get(key: Any): Any
+  def put(key: Any, value: Any): Unit
+}
+
+
+/**
+ * A key namespace in a Cache.
+ */
+class KeySpace(cache: Cache, id: Long) {
+  def get(key: Any): Any = cache.get((id, key))
+  def put(key: Any, value: Any): Unit = cache.put((id, key), value)
+}
+
+
+/**
+ * The Cache object maintains a global Cache instance, of the type specified
+ * by the spark.cache.class property.
+ */
+object Cache {
+  private var instance: Cache = null
+
+  def initialize() {
+    val cacheClass = System.getProperty("spark.cache.class",
+      "spark.SoftReferenceCache")
+    instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+  }
+
+  def getInstance(): Cache = {
+    if (instance == null) {
+      throw new SparkException("Cache.getInstance called before initialize")
+    }
+    instance
+  }
+
+  def newKeySpace(): KeySpace = getInstance().newKeySpace()
+}
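A short sketch of the KeySpace contract described in the doc comment above: two client modules share one Cache but cannot collide, because each KeySpace prefixes its own id onto every key. The module roles in the comments are illustrative, and the sketch assumes a concrete cache such as BoundedMemoryCache is selected via spark.cache.class:

```scala
import spark.Cache

object KeySpaceSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.cache.class", "spark.BoundedMemoryCache")
    Cache.initialize()
    val rddSpace = Cache.newKeySpace()   // e.g. the RDD split cache
    val bcastSpace = Cache.newKeySpace() // e.g. the broadcast variable cache
    // The same user-level key in two spaces becomes (0, "42") and (1, "42")
    rddSpace.put("42", "rdd split data")
    bcastSpace.put("42", "broadcast data")
    println(rddSpace.get("42"))   // rdd split data
    println(bcastSpace.get("42")) // broadcast data
  }
}
```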
diff --git a/src/scala/spark/DfsShuffle.scala b/src/scala/spark/DfsShuffle.scala
new file mode 100644
index 0000000000..7a42bf2d06
--- /dev/null
+++ b/src/scala/spark/DfsShuffle.scala
@@ -0,0 +1,120 @@
+package spark
+
+import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
+import java.net.URI
+import java.util.UUID
+
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+
+/**
+ * A simple implementation of shuffle using a distributed file system.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val dir = DfsShuffle.newTempDirectory()
+    logInfo("Intermediate data directory: " + dir)
+
+    val numberedSplitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = numberedSplitRdd.splits.size
+
+    // Run a parallel foreach to write the intermediate data files
+    numberedSplitRdd.foreach((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+      val fs = DfsShuffle.getFileSystem()
+      for (i <- 0 until numOutputSplits) {
+        val path = new Path(dir, "%d-to-%d".format(myIndex, i))
+        val out = new ObjectOutputStream(fs.create(path, true))
+        buckets(i).foreach(pair => out.writeObject(pair))
+        out.close()
+      }
+    })
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myIndex: Int) => {
+      val combiners = new HashMap[K, C]
+      val fs = DfsShuffle.getFileSystem()
+      for (i <- Utils.shuffle(0 until numInputSplits)) {
+        val path = new Path(dir, "%d-to-%d".format(i, myIndex))
+        val inputStream = new ObjectInputStream(fs.open(path))
+        try {
+          while (true) {
+            val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+            combiners(k) = combiners.get(k) match {
+              case Some(oldC) => mergeCombiners(oldC, c)
+              case None => c
+            }
+          }
+        } catch {
+          case e: EOFException => {}
+        }
+        inputStream.close()
+      }
+      combiners
+    })
+  }
+}
+
+
+/**
+ * Companion object of DfsShuffle; responsible for initializing a Hadoop
+ * FileSystem object based on the spark.dfs property and generating names
+ * for temporary directories.
+ */
+object DfsShuffle {
+  private var initialized = false
+  private var fileSystem: FileSystem = null
+
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+      val dfs = System.getProperty("spark.dfs", "file:///")
+      val conf = new Configuration()
+      conf.setInt("io.file.buffer.size", bufferSize)
+      conf.setInt("dfs.replication", 1)
+      fileSystem = FileSystem.get(new URI(dfs), conf)
+      initialized = true
+    }
+  }
+
+  def getFileSystem(): FileSystem = {
+    initializeIfNeeded()
+    return fileSystem
+  }
+
+  def newTempDirectory(): String = {
+    val fs = getFileSystem()
+    val workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+    val uuid = UUID.randomUUID()
+    val path = workDir + "/shuffle-" + uuid
+    fs.mkdirs(new Path(path))
+    return path
+  }
+}
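The bucketing rule in compute() deserves a note: in Scala, as in Java, the % operator can return a negative result when hashCode is negative, so the code shifts such values back into range. A tiny sketch of that fix in isolation:

```scala
object BucketIdSketch {
  // Same rule DfsShuffle uses: a negative hashCode would make the raw
  // modulo negative, so shift it back into [0, numOutputSplits)
  def bucketOf(key: Any, numOutputSplits: Int): Int = {
    var bucketId = key.hashCode % numOutputSplits
    if (bucketId < 0) bucketId += numOutputSplits
    bucketId
  }

  def main(args: Array[String]) {
    println(bucketOf("spark", 4)) // always in 0..3
    println(bucketOf(-12345, 4))  // hashCode is negative; result is still 3
  }
}
```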
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
index be73aae541..b4d023b428 100644
--- a/src/scala/spark/Executor.scala
+++ b/src/scala/spark/Executor.scala
@@ -1,75 +1,116 @@
 package spark

+import java.io.{File, FileOutputStream}
+import java.net.{URI, URL, URLClassLoader}
 import java.util.concurrent.{Executors, ExecutorService}

+import scala.collection.mutable.ArrayBuffer
+
 import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
 import mesos.{TaskDescription, TaskState, TaskStatus}

 /**
  * The Mesos executor for Spark.
  */
-object Executor extends Logging {
-  def main(args: Array[String]) {
-    System.loadLibrary("mesos")
+class Executor extends mesos.Executor with Logging {
+  var classLoader: ClassLoader = null
+  var threadPool: ExecutorService = null

-    // Create a new Executor implementation that will run our tasks
-    val exec = new mesos.Executor() {
-      var classLoader: ClassLoader = null
-      var threadPool: ExecutorService = null
+  override def init(d: ExecutorDriver, args: ExecutorArgs) {
+    // Read spark.* system properties from executor arg
+    val props = Utils.deserialize[Array[(String, String)]](args.getData)
+    for ((key, value) <- props)
+      System.setProperty(key, value)

-      override def init(d: ExecutorDriver, args: ExecutorArgs) {
-        // Read spark.* system properties
-        val props = Utils.deserialize[Array[(String, String)]](args.getData)
-        for ((key, value) <- props)
-          System.setProperty(key, value)
-
-        // Initialize broadcast system (uses some properties read above)
-        Broadcast.initialize(false)
-
-        // If the REPL is in use, create a ClassLoader that will be able to
-        // read new classes defined by the REPL as the user types code
-        classLoader = this.getClass.getClassLoader
-        val classUri = System.getProperty("spark.repl.class.uri")
-        if (classUri != null) {
-          logInfo("Using REPL class URI: " + classUri)
-          classLoader = new repl.ExecutorClassLoader(classUri, classLoader)
-        }
-        Thread.currentThread.setContextClassLoader(classLoader)
-
-        // Start worker thread pool (they will inherit our context ClassLoader)
-        threadPool = Executors.newCachedThreadPool()
-      }
-
-      override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
-        // Pull taskId and arg out of TaskDescription because it won't be a
-        // valid pointer after this method call (TODO: fix this in C++/SWIG)
-        val taskId = desc.getTaskId
-        val arg = desc.getArg
-        threadPool.execute(new Runnable() {
-          def run() = {
-            logInfo("Running task ID " + taskId)
-            try {
-              Accumulators.clear
-              val task = Utils.deserialize[Task[Any]](arg, classLoader)
-              val value = task.run
-              val accumUpdates = Accumulators.values
-              val result = new TaskResult(value, accumUpdates)
-              d.sendStatusUpdate(new TaskStatus(
-                taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
-              logInfo("Finished task ID " + taskId)
-            } catch {
-              case e: Exception => {
-                // TODO: Handle errors in tasks less dramatically
-                logError("Exception in task ID " + taskId, e)
-                System.exit(1)
-              }
-            }
+    // Initialize cache and broadcast system (uses some properties read above)
+    Cache.initialize()
+    Broadcast.initialize(false)
+
+    // Create our ClassLoader (using spark properties) and set it on this thread
+    classLoader = createClassLoader()
+    Thread.currentThread.setContextClassLoader(classLoader)
+
+    // Start worker thread pool (they will inherit our context ClassLoader)
+    threadPool = Executors.newCachedThreadPool()
+  }
+
+  override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
+    // Pull taskId and arg out of TaskDescription because it won't be a
+    // valid pointer after this method call (TODO: fix this in C++/SWIG)
+    val taskId = desc.getTaskId
+    val arg = desc.getArg
+    threadPool.execute(new Runnable() {
+      def run() = {
+        logInfo("Running task ID " + taskId)
+        try {
+          Accumulators.clear
+          val task = Utils.deserialize[Task[Any]](arg, classLoader)
+          val value = task.run
+          val accumUpdates = Accumulators.values
+          val result = new TaskResult(value, accumUpdates)
+          d.sendStatusUpdate(new TaskStatus(
+            taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
+          logInfo("Finished task ID " + taskId)
+        } catch {
+          case e: Exception => {
+            // TODO: Handle errors in tasks less dramatically
+            logError("Exception in task ID " + taskId, e)
+            System.exit(1)
           }
-        })
+        }
       }
+    })
+  }
+
+  // Create a ClassLoader for use in tasks, adding any JARs specified by the
+  // user or any classes created by the interpreter to the search path
+  private def createClassLoader(): ClassLoader = {
+    var loader = this.getClass.getClassLoader
+
+    // If any JAR URIs are given through spark.jar.uris, fetch them to the
+    // current directory and put them all on the classpath. We assume that
+    // each URL has a unique file name so that no local filenames will clash
+    // in this process. This is guaranteed by MesosScheduler.
+    val uris = System.getProperty("spark.jar.uris", "")
+    val localFiles = ArrayBuffer[String]()
+    for (uri <- uris.split(",").filter(_.size > 0)) {
+      val url = new URL(uri)
+      val filename = url.getPath.split("/").last
+      downloadFile(url, filename)
+      localFiles += filename
+    }
+    if (localFiles.size > 0) {
+      val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
+      loader = new URLClassLoader(urls, loader)
     }

-    // Start it running and connect it to the slave
+    // If the REPL is in use, add another ClassLoader that will read
+    // new classes defined by the REPL as the user types code
+    val classUri = System.getProperty("spark.repl.class.uri")
+    if (classUri != null) {
+      logInfo("Using REPL class URI: " + classUri)
+      loader = new repl.ExecutorClassLoader(classUri, loader)
+    }
+
+    return loader
+  }
+
+  // Download a file from a given URL to the local filesystem
+  private def downloadFile(url: URL, localPath: String) {
+    val in = url.openStream()
+    val out = new FileOutputStream(localPath)
+    Utils.copyStream(in, out, true)
+  }
+}
+
+/**
+ * Executor entry point.
+ */
+object Executor extends Logging {
+  def main(args: Array[String]) {
+    System.loadLibrary("mesos")
+    // Create a new Executor and start it running
+    val exec = new Executor
     new MesosExecutorDriver(exec).run()
   }
 }
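createClassLoader builds a chain of loaders: a URLClassLoader over the fetched JARs in front of the executor's own loader, optionally topped by the REPL loader. A minimal sketch of that layering using only the standard library (extra.jar is a hypothetical path):

```scala
import java.io.File
import java.net.URLClassLoader

object LoaderChainSketch {
  def main(args: Array[String]) {
    val jars = Array(new File("extra.jar").toURI.toURL) // hypothetical JAR
    var loader: ClassLoader = getClass.getClassLoader
    // The child loader adds extra.jar to the search path; by standard
    // delegation it still consults the parent loader first
    loader = new URLClassLoader(jars, loader)
    // Tasks deserialized on this thread can now resolve classes in the JAR
    Thread.currentThread.setContextClassLoader(loader)
  }
}
```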
diff --git a/src/scala/spark/HadoopFile.scala b/src/scala/spark/HadoopFile.scala
new file mode 100644
index 0000000000..a63c9d8a94
--- /dev/null
+++ b/src/scala/spark/HadoopFile.scala
@@ -0,0 +1,118 @@
+package spark
+
+import mesos.SlaveOffer
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.InputSplit
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapred.RecordReader
+import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.util.ReflectionUtils
+
+/** A Spark split class that wraps around a Hadoop InputSplit */
+@serializable class HadoopSplit(@transient s: InputSplit)
+extends Split {
+  val inputSplit = new SerializableWritable[InputSplit](s)
+
+  // Hadoop gives each split a unique toString value, so use this as our ID
+  override def getId() = "HadoopSplit(" + inputSplit.toString + ")"
+}
+
+
+/**
+ * An RDD that reads a Hadoop file (from HDFS, S3, the local filesystem, etc)
+ * and represents it as a set of key-value pairs using a given InputFormat.
+ */
+class HadoopFile[K, V](
+    sc: SparkContext,
+    path: String,
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K],
+    valueClass: Class[V])
+extends RDD[(K, V)](sc) {
+  @transient val splits_ : Array[Split] = ConfigureLock.synchronized {
+    val conf = new JobConf()
+    FileInputFormat.setInputPaths(conf, path)
+    val inputFormat = createInputFormat(conf)
+    val inputSplits = inputFormat.getSplits(conf, sc.numCores)
+    inputSplits.map(x => new HadoopSplit(x): Split).toArray
+  }
+
+  def createInputFormat(conf: JobConf): InputFormat[K, V] = {
+    ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+      .asInstanceOf[InputFormat[K, V]]
+  }
+
+  override def splits = splits_
+
+  override def iterator(theSplit: Split) = new Iterator[(K, V)] {
+    val split = theSplit.asInstanceOf[HadoopSplit]
+    var reader: RecordReader[K, V] = null
+
+    ConfigureLock.synchronized {
+      val conf = new JobConf()
+      val bufferSize = System.getProperty("spark.buffer.size", "65536")
+      conf.set("io.file.buffer.size", bufferSize)
+      val fmt = createInputFormat(conf)
+      reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+    }
+
+    val key: K = keyClass.newInstance()
+    val value: V = valueClass.newInstance()
+    var gotNext = false
+    var finished = false
+
+    override def hasNext: Boolean = {
+      if (!gotNext) {
+        try {
+          finished = !reader.next(key, value)
+        } catch {
+          case eofe: java.io.EOFException =>
+            finished = true
+        }
+        gotNext = true
+      }
+      !finished
+    }
+
+    override def next: (K, V) = {
+      if (!gotNext) {
+        finished = !reader.next(key, value)
+      }
+      if (finished) {
+        throw new java.util.NoSuchElementException("End of stream")
+      }
+      gotNext = false
+      (key, value)
+    }
+  }
+
+  override def preferredLocations(split: Split) = {
+    // TODO: Filtering out "localhost" in case of file:// URLs
+    val hadoopSplit = split.asInstanceOf[HadoopSplit]
+    hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+  }
+}
+
+
+/**
+ * Convenience class for Hadoop files read using TextInputFormat that
+ * represents the file as an RDD of Strings.
+ */
+class HadoopTextFile(sc: SparkContext, path: String)
+extends MappedRDD[String, (LongWritable, Text)](
+  new HadoopFile(sc, path, classOf[TextInputFormat],
+                 classOf[LongWritable], classOf[Text]),
+  { pair: (LongWritable, Text) => pair._2.toString }
+)
+
+
+/**
+ * Object used to ensure that only one thread at a time is configuring Hadoop
+ * InputFormat classes. Apparently configuring them is not thread safe!
+ */
+object ConfigureLock {}
diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala
deleted file mode 100644
index 8637c6e30a..0000000000
--- a/src/scala/spark/HdfsFile.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-package spark
-
-import mesos.SlaveOffer
-
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.mapred.FileInputFormat
-import org.apache.hadoop.mapred.InputSplit
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapred.RecordReader
-import org.apache.hadoop.mapred.Reporter
-
-@serializable class HdfsSplit(@transient s: InputSplit)
-extends Split {
-  val inputSplit = new SerializableWritable[InputSplit](s)
-
-  override def getId() = inputSplit.toString // Hadoop makes this unique
-                                             // for each split of each file
-}
-
-class HdfsTextFile(sc: SparkContext, path: String)
-extends RDD[String](sc) {
-  @transient val conf = new JobConf()
-  @transient val inputFormat = new TextInputFormat()
-
-  FileInputFormat.setInputPaths(conf, path)
-  ConfigureLock.synchronized { inputFormat.configure(conf) }
-
-  @transient val splits_ =
-    inputFormat.getSplits(conf, sc.scheduler.numCores).map(new HdfsSplit(_)).toArray
-
-  override def splits = splits_.asInstanceOf[Array[Split]]
-
-  override def iterator(split_in: Split) = new Iterator[String] {
-    val split = split_in.asInstanceOf[HdfsSplit]
-    var reader: RecordReader[LongWritable, Text] = null
-    ConfigureLock.synchronized {
-      val conf = new JobConf()
-      conf.set("io.file.buffer.size",
-        System.getProperty("spark.buffer.size", "65536"))
-      val tif = new TextInputFormat()
-      tif.configure(conf)
-      reader = tif.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
-    }
-    val lineNum = new LongWritable()
-    val text = new Text()
-    var gotNext = false
-    var finished = false
-
-    override def hasNext: Boolean = {
-      if (!gotNext) {
-        try {
-          finished = !reader.next(lineNum, text)
-        } catch {
-          case eofe: java.io.EOFException =>
-            finished = true
-        }
-        gotNext = true
-      }
-      !finished
-    }
-
-    override def next: String = {
-      if (!gotNext)
-        finished = !reader.next(lineNum, text)
-      if (finished)
-        throw new java.util.NoSuchElementException("end of stream")
-      gotNext = false
-      text.toString
-    }
-  }
-
-  override def preferredLocations(split: Split) = {
-    // TODO: Filtering out "localhost" in case of file:// URLs
-    split.asInstanceOf[HdfsSplit].inputSplit.value.getLocations().filter(_ != "localhost")
-  }
-}
-
-object ConfigureLock {}
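Both the new HadoopFile iterator and the deleted HdfsTextFile one use the same single-record lookahead pattern to adapt Hadoop's pull-style RecordReader to Scala's Iterator. A self-contained sketch of the pattern over a plain BufferedReader (the input source is illustrative):

```scala
import java.io.{BufferedReader, StringReader}

object LookaheadSketch {
  class LineIterator(reader: BufferedReader) extends Iterator[String] {
    private var nextLine: String = null
    private var finished = false

    override def hasNext: Boolean = {
      if (nextLine == null && !finished) {
        nextLine = reader.readLine() // fetch one record ahead
        if (nextLine == null) finished = true
      }
      !finished
    }

    override def next: String = {
      if (!hasNext) throw new java.util.NoSuchElementException("End of stream")
      val line = nextLine
      nextLine = null // consume the buffered record
      line
    }
  }

  def main(args: Array[String]) {
    val it = new LineIterator(new BufferedReader(new StringReader("a\nb\nc")))
    it.foreach(println) // prints a, b, c on separate lines
  }
}
```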
diff --git a/src/scala/spark/HttpServer.scala b/src/scala/spark/HttpServer.scala
new file mode 100644
index 0000000000..d2a663ac1f
--- /dev/null
+++ b/src/scala/spark/HttpServer.scala
@@ -0,0 +1,67 @@
+package spark
+
+import java.io.File
+import java.net.InetAddress
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.handler.DefaultHandler
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+
+
+/**
+ * Exception type thrown by HttpServer when it is in the wrong state
+ * for an operation.
+ */
+class ServerStateException(message: String) extends Exception(message)
+
+
+/**
+ * An HTTP server for static content used to allow worker nodes to access JARs
+ * added to SparkContext as well as classes created by the interpreter when
+ * the user types in code. This is just a wrapper around a Jetty server.
+ */
+class HttpServer(resourceBase: File) extends Logging {
+  private var server: Server = null
+  private var port: Int = -1
+
+  def start() {
+    if (server != null) {
+      throw new ServerStateException("Server is already started")
+    } else {
+      server = new Server(0)
+      val threadPool = new QueuedThreadPool
+      threadPool.setDaemon(true)
+      server.setThreadPool(threadPool)
+      val resHandler = new ResourceHandler
+      resHandler.setResourceBase(resourceBase.getAbsolutePath)
+      val handlerList = new HandlerList
+      handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+      server.setHandler(handlerList)
+      server.start()
+      port = server.getConnectors()(0).getLocalPort()
+    }
+  }
+
+  def stop() {
+    if (server == null) {
+      throw new ServerStateException("Server is already stopped")
+    } else {
+      server.stop()
+      port = -1
+      server = null
+    }
+  }
+
+  /**
+   * Get the URI of this HTTP server (http://host:port)
+   */
+  def uri: String = {
+    if (server == null) {
+      throw new ServerStateException("Server is not started")
+    } else {
+      return "http://" + Utils.localIpAddress + ":" + port
+    }
+  }
+}
diff --git a/src/scala/spark/Job.scala b/src/scala/spark/Job.scala
new file mode 100644
index 0000000000..6abbcbce51
--- /dev/null
+++ b/src/scala/spark/Job.scala
@@ -0,0 +1,18 @@
+package spark
+
+import mesos._
+
+/**
+ * Class representing a parallel job in MesosScheduler. Schedules the
+ * job by implementing various callbacks.
+ */
+abstract class Job(jobId: Int) {
+  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
+    : Option[TaskDescription]
+
+  def statusUpdate(t: TaskStatus): Unit
+
+  def error(code: Int, message: String): Unit
+
+  def getId(): Int = jobId
+}
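A short sketch of how the scheduler side uses this class (see createJarServer in the MesosScheduler diff below): point HttpServer at a directory, start it, and publish file URLs under its uri. The directory path here is a placeholder:

```scala
import java.io.File

object JarServerSketch {
  def main(args: Array[String]) {
    val jarDir = new File("/tmp/jars") // hypothetical directory of JARs
    val server = new spark.HttpServer(jarDir)
    server.start()
    // Executors fetch <uri>/<filename>; uri looks like http://<ip>:<port>
    println("Serving " + jarDir + " at " + server.uri)
    server.stop()
  }
}
```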
diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala
new file mode 100644
index 0000000000..367599cfb4
--- /dev/null
+++ b/src/scala/spark/LocalFileShuffle.scala
@@ -0,0 +1,171 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+
+/**
+ * A simple implementation of shuffle using local files served through HTTP.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val shuffleId = LocalFileShuffle.newShuffleId()
+    logInfo("Shuffle ID: " + shuffleId)
+
+    val splitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = splitRdd.splits.size
+
+    // Run a parallel map and collect to write the intermediate data files,
+    // returning a list of inputSplitId -> serverUri pairs
+    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+      for (i <- 0 until numOutputSplits) {
+        val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val out = new ObjectOutputStream(new FileOutputStream(file))
+        buckets(i).foreach(pair => out.writeObject(pair))
+        out.close()
+      }
+      (myIndex, LocalFileShuffle.serverUri)
+    }).collect()
+
+    // Build a hashmap from server URI to list of splits (to facilitate
+    // fetching all the URIs on a server within a single connection)
+    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+    for ((inputId, serverUri) <- outputLocs) {
+      splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
+    }
+
+    // TODO: Could broadcast splitsByUri
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myId: Int) => {
+      val combiners = new HashMap[K, C]
+      for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
+        for (i <- inputIds) {
+          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          try {
+            while (true) {
+              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+              combiners(k) = combiners.get(k) match {
+                case Some(oldC) => mergeCombiners(oldC, c)
+                case None => c
+              }
+            }
+          } catch {
+            case e: EOFException => {}
+          }
+          inputStream.close()
+        }
+      }
+      combiners
+    })
+  }
+}
+
+
+object LocalFileShuffle extends Logging {
+  private var initialized = false
+  private var nextShuffleId = new AtomicLong(0)
+
+  // Variables initialized by initializeIfNeeded()
+  private var shuffleDir: File = null
+  private var server: HttpServer = null
+  private var serverUri: String = null
+
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      // TODO: localDir should be created by some mechanism common to Spark
+      // so that it can be shared among shuffle, broadcast, etc
+      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+      var tries = 0
+      var foundLocalDir = false
+      var localDir: File = null
+      var localDirUuid: UUID = null
+      while (!foundLocalDir && tries < 10) {
+        tries += 1
+        try {
+          localDirUuid = UUID.randomUUID()
+          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+          if (!localDir.exists()) {
+            localDir.mkdirs()
+            foundLocalDir = true
+          }
+        } catch {
+          case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir failed", e)
+        }
+      }
+      if (!foundLocalDir) {
+        logError("Failed 10 attempts to create local dir in " + localDirRoot)
+        System.exit(1)
+      }
+      shuffleDir = new File(localDir, "shuffle")
+      shuffleDir.mkdirs()
+      logInfo("Shuffle dir: " + shuffleDir)
+      val extServerPort = System.getProperty(
+        "spark.localFileShuffle.external.server.port", "-1").toInt
+      if (extServerPort != -1) {
+        // We're using an external HTTP server; set URI relative to its root
+        var extServerPath = System.getProperty(
+          "spark.localFileShuffle.external.server.path", "")
+        if (extServerPath != "" && !extServerPath.endsWith("/")) {
+          extServerPath += "/"
+        }
+        serverUri = "http://%s:%d/%s/spark-local-%s".format(
+          Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+      } else {
+        // Create our own server
+        server = new HttpServer(localDir)
+        server.start()
+        serverUri = server.uri
+      }
+      initialized = true
+    }
+  }
+
+  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+    initializeIfNeeded()
+    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+    dir.mkdirs()
+    val file = new File(dir, "" + outputId)
+    return file
+  }
+
+  def getServerUri(): String = {
+    initializeIfNeeded()
+    serverUri
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+}
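The writer and reader sides of LocalFileShuffle agree on a simple naming scheme: getOutputFile stores bucket files under <localDir>/shuffle/<shuffleId>/<inputId>/<outputId>, and reducers rebuild the same coordinates into a fetch URL against the serving worker. A sketch of the correspondence (the server URI is a placeholder):

```scala
object ShufflePathSketch {
  def main(args: Array[String]) {
    val serverUri = "http://10.0.0.1:33000" // hypothetical worker server
    val (shuffleId, inputId, outputId) = (7, 2, 5)
    val relativeFile = "shuffle/%d/%d/%d".format(shuffleId, inputId, outputId)
    val fetchUrl = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, outputId)
    println(relativeFile) // shuffle/7/2/5, relative to the served localDir
    println(fetchUrl)     // http://10.0.0.1:33000/shuffle/7/2/5
  }
}
```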
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index 873a97c59c..c45eff64d4 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -1,103 +1,130 @@
 package spark

-import java.io.File
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}

+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
 import scala.collection.mutable.Map
 import scala.collection.mutable.Queue
-import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._

-import mesos.{Scheduler => NScheduler}
+import mesos.{Scheduler => MScheduler}
 import mesos._

-// The main Scheduler implementation, which talks to Mesos. Clients are expected
-// to first call start(), then submit tasks through the runTasks method.
-//
-// This implementation is currently a little quick and dirty. The following
-// improvements need to be made to it:
-// 1) Right now, the scheduler uses a linear scan through the tasks to find a
-//    local one for a given node. It would be faster to have a separate list of
-//    pending tasks for each node.
-// 2) Presenting a single slave in ParallelOperation.slaveOffer makes it
-//    difficult to balance tasks across nodes. It would be better to pass
-//    all the offers to the ParallelOperation and have it load-balance.
+/**
+ * The main Scheduler implementation, which runs jobs on Mesos. Clients should
+ * first call start(), then submit tasks through the runTasks method.
+ */
 private class MesosScheduler(
-  master: String, frameworkName: String, execArg: Array[Byte])
-extends NScheduler with spark.Scheduler with Logging
+  sc: SparkContext, master: String, frameworkName: String)
+extends MScheduler with spark.Scheduler with Logging
 {
-  // Lock used by runTasks to ensure only one thread can be in it
-  val runTasksMutex = new Object()
+  // Environment variables to pass to our executors
+  val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
+    "SPARK_MEM",
+    "SPARK_CLASSPATH",
+    "SPARK_LIBRARY_PATH"
+  )

   // Lock used to wait for scheduler to be registered
-  var isRegistered = false
-  val registeredLock = new Object()
+  private var isRegistered = false
+  private val registeredLock = new Object()

-  // Current callback object (may be null)
-  var activeOpsQueue = new Queue[Int]
-  var activeOps = new HashMap[Int, ParallelOperation]
-  private var nextOpId = 0
-  private[spark] var taskIdToOpId = new HashMap[Int, Int]
-
-  def newOpId(): Int = {
-    val id = nextOpId
-    nextOpId += 1
-    return id
-  }
+  private var activeJobs = new HashMap[Int, Job]
+  private var activeJobsQueue = new Queue[Job]
+
+  private var taskIdToJobId = new HashMap[Int, Int]
+  private var jobTasks = new HashMap[Int, HashSet[Int]]

-  // Incrementing task ID
+  // Incrementing job and task IDs
+  private var nextJobId = 0
   private var nextTaskId = 0

+  // Driver for talking to Mesos
+  var driver: SchedulerDriver = null
+
+  // JAR server, if any JARs were added by the user to the SparkContext
+  var jarServer: HttpServer = null
+
+  // URIs of JARs to pass to executor
+  var jarUris: String = ""
+
+  def newJobId(): Int = this.synchronized {
+    val id = nextJobId
+    nextJobId += 1
+    return id
+  }
+
   def newTaskId(): Int = {
     val id = nextTaskId;
     nextTaskId += 1;
     return id
   }
-
-  // Driver for talking to Mesos
-  var driver: SchedulerDriver = null

   override def start() {
+    if (sc.jars.size > 0) {
+      // If the user added any JARS to the SparkContext, create an HTTP server
+      // to serve them to our executors
+      createJarServer()
+    }
     new Thread("Spark scheduler") {
       setDaemon(true)
       override def run {
-        val ns = MesosScheduler.this
-        ns.driver = new MesosSchedulerDriver(ns, master)
-        ns.driver.run()
+        val sched = MesosScheduler.this
+        sched.driver = new MesosSchedulerDriver(sched, master)
+        sched.driver.run()
       }
     }.start
   }

   override def getFrameworkName(d: SchedulerDriver): String = frameworkName

-  override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
-    new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
+  override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = {
+    val sparkHome = sc.getSparkHome match {
+      case Some(path) => path
+      case None =>
+        throw new SparkException("Spark home is not set; set it through the " +
+          "spark.home system property, the SPARK_HOME environment variable " +
+          "or the SparkContext constructor")
+    }
+    val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
+    val params = new JHashMap[String, String]
+    for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
+      if (System.getenv(key) != null) {
+        params("env." + key) = System.getenv(key)
+      }
+    }
+    new ExecutorInfo(execScript, createExecArg())
+  }

+  /**
+   * The primary means to submit a job to the scheduler. Given a list of tasks,
+   * runs them and returns an array of the results.
+   */
   override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
-    var opId = 0
     waitForRegister()
-    this.synchronized {
-      opId = newOpId()
-    }
-    val myOp = new SimpleParallelOperation(this, tasks, opId)
-
+    val jobId = newJobId()
+    val myJob = new SimpleJob(this, tasks, jobId)
     try {
       this.synchronized {
-        this.activeOps(myOp.opId) = myOp
-        this.activeOpsQueue += myOp.opId
+        activeJobs(jobId) = myJob
+        activeJobsQueue += myJob
+        jobTasks(jobId) = new HashSet()
       }
       driver.reviveOffers();
-      myOp.join();
+      return myJob.join();
     } finally {
       this.synchronized {
-        this.activeOps.remove(myOp.opId)
-        this.activeOpsQueue.dequeueAll(x => (x == myOp.opId))
+        activeJobs -= jobId
+        activeJobsQueue.dequeueAll(x => (x == myJob))
+        taskIdToJobId --= jobTasks(jobId)
+        jobTasks.remove(jobId)
       }
     }
-
-    if (myOp.errorHappened)
-      throw new SparkException(myOp.errorMessage, myOp.errorCode)
-    else
-      return myOp.results
   }

   override def registered(d: SchedulerDriver, frameworkId: String) {
@@ -115,51 +142,68 @@ extends NScheduler with spark.Scheduler with Logging
     }
   }

+  /**
+   * Method called by Mesos to offer resources on slaves. We respond by asking
+   * our active jobs for tasks in FIFO order. We fill each node with tasks in
+   * a round-robin manner so that tasks are balanced across the cluster.
+   */
   override def resourceOffer(
-      d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
+      d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) {
     synchronized {
-      val tasks = new java.util.ArrayList[TaskDescription]
+      val tasks = new JArrayList[TaskDescription]
       val availableCpus = offers.map(_.getParams.get("cpus").toInt)
       val availableMem = offers.map(_.getParams.get("mem").toInt)
-      var launchedTask = true
-      for (opId <- activeOpsQueue) {
-        launchedTask = true
-        while (launchedTask) {
+      var launchedTask = false
+      for (job <- activeJobsQueue) {
+        do {
           launchedTask = false
           for (i <- 0 until offers.size.toInt) {
             try {
-              activeOps(opId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+              job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match {
                 case Some(task) =>
                   tasks.add(task)
+                  taskIdToJobId(task.getTaskId) = job.getId
+                  jobTasks(job.getId) += task.getTaskId
                   availableCpus(i) -= task.getParams.get("cpus").toInt
                   availableMem(i) -= task.getParams.get("mem").toInt
-                  launchedTask = launchedTask || true
+                  launchedTask = true
                 case None => {}
               }
             } catch {
              case e: Exception => logError("Exception in resourceOffer", e)
            }
          }
-        }
+        } while (launchedTask)
       }
-      val params = new java.util.HashMap[String, String]
+      val params = new JHashMap[String, String]
       params.put("timeout", "1")
-      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
+      d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout?
     }
   }

+  // Check whether a Mesos task state represents a finished task
+  def isFinished(state: TaskState) = {
+    state == TaskState.TASK_FINISHED ||
+    state == TaskState.TASK_FAILED ||
+    state == TaskState.TASK_KILLED ||
+    state == TaskState.TASK_LOST
+  }
+
   override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
     synchronized {
       try {
-        taskIdToOpId.get(status.getTaskId) match {
-          case Some(opId) =>
-            if (activeOps.contains(opId)) {
-              activeOps(opId).statusUpdate(status)
+        taskIdToJobId.get(status.getTaskId) match {
+          case Some(jobId) =>
+            if (activeJobs.contains(jobId)) {
+              activeJobs(jobId).statusUpdate(status)
+            }
+            if (isFinished(status.getState)) {
+              taskIdToJobId.remove(status.getTaskId)
+              jobTasks(jobId) -= status.getTaskId
             }
           case None =>
             logInfo("TID " + status.getTaskId + " already finished")
         }
-
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
       }
@@ -167,180 +211,84 @@ extends NScheduler with spark.Scheduler with Logging
   }

   override def error(d: SchedulerDriver, code: Int, message: String) {
+    logError("Mesos error: %s (error code: %d)".format(message, code))
     synchronized {
-      if (activeOps.size > 0) {
-        for ((opId, activeOp) <- activeOps) {
+      if (activeJobs.size > 0) {
+        // Have each job throw a SparkException with the error
+        for ((jobId, activeJob) <- activeJobs) {
           try {
-            activeOp.error(code, message)
+            activeJob.error(code, message)
           } catch {
             case e: Exception => logError("Exception in error callback", e)
           }
         }
       } else {
-        logError("Mesos error: %s (error code: %d)".format(message, code))
+        // No jobs are active but we still got an error. Just exit since this
+        // must mean the error is during registration.
+        // It might be good to do something smarter here in the future.
         System.exit(1)
       }
     }
   }

   override def stop() {
-    if (driver != null)
+    if (driver != null) {
       driver.stop()
-  }
-
-  // TODO: query Mesos for number of cores
-  override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt
-}
-
-
-// Trait representing an object that manages a parallel operation by
-// implementing various scheduler callbacks.
-trait ParallelOperation {
-  def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription]
-  def statusUpdate(t: TaskStatus): Unit
-  def error(code: Int, message: String): Unit
-}
-
-
-class SimpleParallelOperation[T: ClassManifest](
-  sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int)
-extends ParallelOperation with Logging
-{
-  // Maximum time to wait to run a task in a preferred location (in ms)
-  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
-  val callingThread = currentThread
-  val numTasks = tasks.length
-  val results = new Array[T](numTasks)
-  val launched = new Array[Boolean](numTasks)
-  val finished = new Array[Boolean](numTasks)
-  val tidToIndex = Map[Int, Int]()
-
-  var allFinished = false
-  val joinLock = new Object()
-
-  var errorHappened = false
-  var errorCode = 0
-  var errorMessage = ""
-
-  var tasksLaunched = 0
-  var tasksFinished = 0
-  var lastPreferredLaunchTime = System.currentTimeMillis
-
-  def setAllFinished() {
-    joinLock.synchronized {
-      allFinished = true
-      joinLock.notifyAll()
-    }
-  }
-
-  def join() {
-    joinLock.synchronized {
-      while (!allFinished)
-        joinLock.wait()
+    }
+    if (jarServer != null) {
+      jarServer.stop()
     }
   }

-  def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] = {
-    if (tasksLaunched < numTasks) {
-      var checkPrefVals: Array[Boolean] = Array(true)
-      val time = System.currentTimeMillis
-      if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
-        checkPrefVals = Array(true, false) // Allow non-preferred tasks
-      // TODO: Make desiredCpus and desiredMem configurable
-      val desiredCpus = 1
-      val desiredMem = 500
-      if ((availableCpus < desiredCpus) || (availableMem < desiredMem))
-        return None
-      for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
-        if (!launched(i) && (!checkPref ||
-            tasks(i).preferredLocations.contains(offer.getHost) ||
-            tasks(i).preferredLocations.isEmpty))
-        {
-          val taskId = sched.newTaskId()
-          sched.taskIdToOpId(taskId) = opId
-          tidToIndex(taskId) = i
-          val preferred = if(checkPref) "preferred" else "non-preferred"
-          val message =
-            "Starting task %d as opId %d, TID %s on slave %s: %s (%s)".format(
-              i, opId, taskId, offer.getSlaveId, offer.getHost, preferred)
-          logInfo(message)
-          tasks(i).markStarted(offer)
-          launched(i) = true
-          tasksLaunched += 1
-          if (checkPref)
-            lastPreferredLaunchTime = time
-          val params = new java.util.HashMap[String, String]
-          params.put("cpus", "" + desiredCpus)
-          params.put("mem", "" + desiredMem)
-          val serializedTask = Utils.serialize(tasks(i))
-          //logInfo("Serialized size: " + serializedTask.size)
-          return Some(new TaskDescription(taskId, offer.getSlaveId,
-            "task_" + taskId, params, serializedTask))
-        }
-      }
-    }
-    return None
-  }
+  // TODO: query Mesos for number of cores
+  override def numCores() =
+    System.getProperty("spark.default.parallelism", "2").toInt

-  def statusUpdate(status: TaskStatus) {
-    status.getState match {
-      case TaskState.TASK_FINISHED =>
-        taskFinished(status)
-      case TaskState.TASK_LOST =>
-        taskLost(status)
-      case TaskState.TASK_FAILED =>
-        taskLost(status)
-      case TaskState.TASK_KILLED =>
-        taskLost(status)
-      case _ =>
+  // Create a server for all the JARs added by the user to SparkContext.
+  // We first copy the JARs to a temp directory for easier server setup.
+  private def createJarServer() {
+    val jarDir = Utils.createTempDir()
+    logInfo("Temp directory for JARs: " + jarDir)
+    val filenames = ArrayBuffer[String]()
+    // Copy each JAR to a unique filename in the jarDir
+    for ((path, index) <- sc.jars.zipWithIndex) {
+      val file = new File(path)
+      val filename = index + "_" + file.getName
+      copyFile(file, new File(jarDir, filename))
+      filenames += filename
     }
+    // Create the server
+    jarServer = new HttpServer(jarDir)
+    jarServer.start()
+    // Build up the jar URI list
+    val serverUri = jarServer.uri
+    jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+    logInfo("JAR server started at " + serverUri)
   }

-  def taskFinished(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      tasksFinished += 1
-      logInfo("Finished opId %d TID %d (progress: %d/%d)".format(
-        opId, tid, tasksFinished, numTasks))
-      // Deserialize task result
-      val result = Utils.deserialize[TaskResult[T]](status.getData)
-      results(index) = result.value
-      // Update accumulators
-      Accumulators.add(callingThread, result.accumUpdates)
-      // Mark finished and stop if we've finished all the tasks
-      finished(index) = true
-      // Remove TID -> opId mapping from sched
-      sched.taskIdToOpId.remove(tid)
-      if (tasksFinished == numTasks)
-        setAllFinished()
-    } else {
-      logInfo("Ignoring task-finished event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
+  // Copy a file on the local file system
+  private def copyFile(source: File, dest: File) {
+    val in = new FileInputStream(source)
+    val out = new FileOutputStream(dest)
+    Utils.copyStream(in, out, true)
   }

-  def taskLost(status: TaskStatus) {
-    val tid = status.getTaskId
-    val index = tidToIndex(tid)
-    if (!finished(index)) {
-      logInfo("Lost opId " + opId + " TID " + tid)
-      launched(index) = false
-      sched.taskIdToOpId.remove(tid)
-      tasksLaunched -= 1
-    } else {
-      logInfo("Ignoring task-lost event for TID " + tid +
-        " because task " + index + " is already finished")
-    }
-  }
-
-  def error(code: Int, message: String) {
-    // Save the error message
-    errorHappened = true
-    errorCode = code
-    errorMessage = message
-    // Indicate to caller thread that we're done
-    setAllFinished()
+  // Create and serialize the executor argument to pass to Mesos.
+  // Our executor arg is an array containing all the spark.* system properties
+  // in the form of (String, String) pairs.
+  private def createExecArg(): Array[Byte] = {
+    val props = new HashMap[String, String]
+    val iter = System.getProperties.entrySet.iterator
+    while (iter.hasNext) {
+      val entry = iter.next
+      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+      if (key.startsWith("spark.")) {
+        props(key) = value
+      }
+    }
+    // Set spark.jar.uris to our JAR URIs, regardless of system property
+    props("spark.jar.uris") = jarUris
+    // Serialize the map as an array of (String, String) pairs
+    return Utils.serialize(props.toArray)
   }
 }
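createExecArg and Executor.init form a simple round trip: spark.* system properties are serialized as an Array[(String, String)] on the scheduler side and restored with System.setProperty on each worker. A standalone sketch of that round trip using plain Java serialization (the property values are illustrative):

```scala
import java.io._

object ExecArgSketch {
  def serialize(o: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream
    val out = new ObjectOutputStream(bos)
    out.writeObject(o); out.close()
    bos.toByteArray
  }

  def main(args: Array[String]) {
    val props = Array(("spark.default.parallelism", "8"),
                      ("spark.cache.class", "spark.BoundedMemoryCache"))
    val bytes = serialize(props) // scheduler side: packed into ExecutorInfo
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
    val restored = in.readObject.asInstanceOf[Array[(String, String)]]
    for ((key, value) <- restored) // executor side: Executor.init
      System.setProperty(key, value)
    println(System.getProperty("spark.default.parallelism")) // 8
  }
}
```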
diff --git a/src/scala/spark/NumberedSplitRDD.scala b/src/scala/spark/NumberedSplitRDD.scala
new file mode 100644
index 0000000000..7b12210d84
--- /dev/null
+++ b/src/scala/spark/NumberedSplitRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import mesos.SlaveOffer
+
+
+/**
+ * An RDD that takes the splits of a parent RDD and gives them unique indexes.
+ * This is useful for a variety of shuffle implementations.
+ */
+class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
+extends RDD[(Int, Iterator[T])](prev.sparkContext) {
+  @transient val splits_ = {
+    prev.splits.zipWithIndex.map {
+      case (s, i) => new NumberedSplitRDDSplit(s, i): Split
+    }.toArray
+  }
+
+  override def splits = splits_
+
+  override def preferredLocations(split: Split) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    prev.preferredLocations(nsplit.prev)
+  }
+
+  override def iterator(split: Split) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    Iterator((nsplit.index, prev.iterator(nsplit.prev)))
+  }
+
+  override def taskStarted(split: Split, slot: SlaveOffer) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    prev.taskStarted(nsplit.prev, slot)
+  }
+}
+
+
+/**
+ * A split in a NumberedSplitRDD.
+ */
+class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
+  override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
+}
 @serializable
@@ -129,7 +140,7 @@ extends RDDTask[Option[T], T](rdd, split) with Logging {
 }
 
 class MappedRDD[U: ClassManifest, T: ClassManifest](
-    prev: RDD[T], f: T => U)
+  prev: RDD[T], f: T => U)
 extends RDD[U](prev.sparkContext) {
   override def splits = prev.splits
   override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -138,7 +149,7 @@ extends RDD[U](prev.sparkContext) {
 }
 
 class FilteredRDD[T: ClassManifest](
-    prev: RDD[T], f: T => Boolean)
+  prev: RDD[T], f: T => Boolean)
 extends RDD[T](prev.sparkContext) {
   override def splits = prev.splits
   override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -147,7 +158,7 @@ extends RDD[T](prev.sparkContext) {
 }
 
 class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
-    prev: RDD[T], f: T => Traversable[U])
+  prev: RDD[T], f: T => Traversable[U])
 extends RDD[U](prev.sparkContext) {
   override def splits = prev.splits
   override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -156,7 +167,7 @@ extends RDD[U](prev.sparkContext) {
   override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
 }
 
-class SplitRDD[T: ClassManifest](prev: RDD[T])
+class SplitRDD[T: ClassManifest](prev: RDD[T])
 extends RDD[Array[T]](prev.sparkContext) {
   override def splits = prev.splits
   override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -171,16 +182,16 @@ extends RDD[Array[T]](prev.sparkContext) {
 }
 
 class SampledRDD[T: ClassManifest](
-    prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int)
+  prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int)
 extends RDD[T](prev.sparkContext) {
-
+
   @transient val splits_ = {
     val rg = new Random(seed);
     prev.splits.map(x => new SeededSplit(x, rg.nextInt))
   }
 
   override def splits = splits_.asInstanceOf[Array[Split]]
 
   override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SeededSplit].prev)
 
-  override def iterator(splitIn: Split) = {
+  override def iterator(splitIn: Split) = {
     val split = splitIn.asInstanceOf[SeededSplit]
     val rg = new Random(split.seed);
     // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
@@ -214,7 +225,7 @@ extends RDD[T](prev.sparkContext) with Logging {
     else
       prev.preferredLocations(split)
   }
-
+
   override def iterator(split: Split): Iterator[T] = {
     val key = id + "::" + split.getId()
     logInfo("CachedRDD split key is " + key)
@@ -261,45 +272,36 @@ private object CachedRDD {
   def newId() = nextId.getAndIncrement()
 
   // Stores map results for various splits locally (on workers)
-  val cache = new MapMaker().softValues().makeMap[String, AnyRef]()
+  val cache = Cache.newKeySpace()
 
   // Remembers which splits are currently being loaded (on workers)
   val loading = new HashSet[String]
 }
 
 @serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
-  def iterator(): Iterator[T]
-  def preferredLocations(): Seq[String]
-  def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
-    rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
-  override def iterator() = rdd.iterator(split)
-  override def preferredLocations() = rdd.preferredLocations(split)
-  override def getId() =
-    "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+  def iterator() = rdd.iterator(split)
+  def preferredLocations() = rdd.preferredLocations(split)
+  override def getId() = "UnionSplit(" + split.getId() + ")"
 }
 
 @serializable
-class UnionRDD[T: ClassManifest](
-    sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
 extends RDD[T](sc) {
-
-  @transient val splits_ : Array[UnionSplit[T]] = {
-    val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
-    val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
-    (a1 ++ a2).toArray
+  @transient val splits_ : Array[Split] = {
+    val splits: Seq[Split] =
+      for (rdd <- rdds; split <- rdd.splits)
+        yield new UnionSplit(rdd, split)
+    splits.toArray
   }
 
-  override def splits = splits_.asInstanceOf[Array[Split]]
+  override def splits = splits_
 
-  override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+  override def iterator(s: Split): Iterator[T] =
+    s.asInstanceOf[UnionSplit[T]].iterator()
 
-  override def preferredLocations(s: Split): Seq[String] =
+  override def preferredLocations(s: Split): Seq[String] =
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
 }
@@ -336,8 +338,8 @@ extends RDD[Pair[T, U]](sc) {
   }
 }
 
-@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
-  def reduceByKey(func: (V, V) => V): Map[K, V] = {
+@serializable class PairRDDExtras[K, V](self: RDD[(K, V)]) {
+  def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
     def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
       for ((k, v) <- m2) {
         m1.get(k) match {
@@ -347,6 +349,70 @@ extends RDD[Pair[T, U]](sc) {
       }
       return m1
     }
-    rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
+    self.map(pair => HashMap(pair)).reduce(mergeMaps)
+  }
+
+  def combineByKey[C](createCombiner: V => C,
+                      mergeValue: (C, V) => C,
+                      mergeCombiners: (C, C) => C,
+                      numSplits: Int)
+  : RDD[(K, C)] =
+  {
+    val shufClass = Class.forName(System.getProperty(
+      "spark.shuffle.class", "spark.DfsShuffle"))
+    val shuf = shufClass.newInstance().asInstanceOf[Shuffle[K, V, C]]
+    shuf.compute(self, numSplits, createCombiner, mergeValue, mergeCombiners)
+  }
+
+  def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
+    combineByKey[V]((v: V) => v, func, func, numSplits)
+  }
+
+  def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
+    def createCombiner(v: V) = ArrayBuffer(v)
+    def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+    def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
+    val bufs = combineByKey[ArrayBuffer[V]](
+      createCombiner _, mergeValue _, mergeCombiners _, numSplits)
+    bufs.asInstanceOf[RDD[(K, Seq[V])]]
+  }
+
+  def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
+    val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
+    val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
+    (vs ++ ws).groupByKey(numSplits).flatMap {
+      case (k, seq) => {
+        val vbuf = new ArrayBuffer[V]
+        val wbuf = new ArrayBuffer[W]
+        seq.foreach(_ match {
+          case Left(v) => vbuf += v
+          case Right(w) => wbuf += w
+        })
+        for (v <- vbuf; w <- wbuf) yield (k, (v, w))
+      }
+    }
+  }
+
+  def combineByKey[C](createCombiner: V => C,
+                      mergeValue: (C, V) => C,
+                      mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] = {
+    combineByKey(createCombiner, mergeValue, mergeCombiners, numCores)
+  }
+
+  def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
+    reduceByKey(func, numCores)
+  }
+
+  def groupByKey(): RDD[(K, Seq[V])] = {
+    groupByKey(numCores)
+  }
+
+  def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
+    join(other, numCores)
+  }
+
+  def numCores = self.sparkContext.numCores
+
+  def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
 }
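The PairRDDExtras additions route reduceByKey, groupByKey and join through combineByKey, which delegates to the pluggable Shuffle trait introduced next. A short usage sketch, assuming the implicit conversion from RDD[(K, V)] to PairRDDExtras that SparkContext._ is expected to provide:

    import SparkContext._

    val pairs = sc.parallelize(Array(("a", 1), ("b", 1), ("a", 2)))
    val counts = pairs.reduceByKey(_ + _)       // RDD[(String, Int)]
    val groups = pairs.groupByKey()             // RDD[(String, Seq[Int])]
    val joined = counts.join(groups)            // RDD[(String, (Int, Seq[Int]))]
    // combineByKey generalizes all of these; e.g. per-key (sum, count):
    val sumCounts = pairs.combineByKey[(Int, Int)](
      v => (v, 1),
      (acc, v) => (acc._1 + v, acc._2 + 1),
      (a, b) => (a._1 + b._1, a._2 + b._2))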
diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala
new file mode 100644
index 0000000000..4c5649b537
--- /dev/null
+++ b/src/scala/spark/Shuffle.scala
@@ -0,0 +1,15 @@
+package spark
+
+/**
+ * A trait for the shuffle system. Given an input RDD and combiner functions
+ * for PairRDDExtras.combineByKey(), it returns an output RDD.
+ */
+@serializable
+trait Shuffle[K, V, C] {
+  def compute(input: RDD[(K, V)],
+              numOutputSplits: Int,
+              createCombiner: V => C,
+              mergeValue: (C, V) => C,
+              mergeCombiners: (C, C) => C)
+  : RDD[(K, C)]
+}
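As the combineByKey change above shows, the concrete implementation is loaded reflectively from the spark.shuffle.class system property, defaulting to spark.DfsShuffle, so an alternative shuffle can be swapped in without touching the RDD API. For example (spark.CustomShuffle is a hypothetical class name):

    // Must implement spark.Shuffle[K, V, C] and have a no-arg constructor
    System.setProperty("spark.shuffle.class", "spark.CustomShuffle")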
diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
new file mode 100644
index 0000000000..09846ccc34
--- /dev/null
+++ b/src/scala/spark/SimpleJob.scala
@@ -0,0 +1,272 @@
+package spark
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import mesos._
+
+
+/**
+ * A Job that runs a set of tasks with no interdependencies.
+ */
+class SimpleJob[T: ClassManifest](
+  sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
+extends Job(jobId) with Logging
+{
+  // Maximum time to wait to run a task in a preferred location (in ms)
+  val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+  // CPUs and memory to request per task
+  val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+  val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
+
+  // Maximum number of times a task is allowed to fail before the whole job fails
+  val MAX_TASK_FAILURES = 4
+
+  val callingThread = currentThread
+  val numTasks = tasks.length
+  val results = new Array[T](numTasks)
+  val launched = new Array[Boolean](numTasks)
+  val finished = new Array[Boolean](numTasks)
+  val numFailures = new Array[Int](numTasks)
+  val tidToIndex = HashMap[Int, Int]()
+
+  var allFinished = false
+  val joinLock = new Object() // Used to wait for all tasks to finish
+
+  var tasksLaunched = 0
+  var tasksFinished = 0
+
+  // Last time when we launched a preferred task (for delay scheduling)
+  var lastPreferredLaunchTime = System.currentTimeMillis
+
+  // List of pending tasks for each node. These collections are actually
+  // treated as stacks, in which new tasks are added to the end of the
+  // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because whenever a task fails, it is put
+  // back at the head of the stack. They are also only cleaned up lazily;
+  // when a task is launched, it remains in all the pending lists except
+  // the one that it was launched from, but gets removed from them later.
+  val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+  // List containing pending tasks with no locality preferences
+  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+  // List containing all pending tasks (also used as a stack, as above)
+  val allPendingTasks = new ArrayBuffer[Int]
+
+  // Did the job fail?
+  var failed = false
+  var causeOfFailure = ""
+
+  // Add all our tasks to the pending lists. We do this in reverse order
+  // of task index so that tasks with low indices get launched first.
+  for (i <- (0 until numTasks).reverse) {
+    addPendingTask(i)
+  }
+
+  // Add a task to all the pending-task lists that it should be on.
+  def addPendingTask(index: Int) {
+    val locations = tasks(index).preferredLocations
+    if (locations.size == 0) {
+      pendingTasksWithNoPrefs += index
+    } else {
+      for (host <- locations) {
+        val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+        list += index
+      }
+    }
+    allPendingTasks += index
+  }
+
+  // Mark the job as finished and wake up any threads waiting on it
+  def setAllFinished() {
+    joinLock.synchronized {
+      allFinished = true
+      joinLock.notifyAll()
+    }
+  }
+
+  // Wait until the job finishes and return its results
+  def join(): Array[T] = {
+    joinLock.synchronized {
+      while (!allFinished) {
+        joinLock.wait()
+      }
+      if (failed) {
+        throw new SparkException(causeOfFailure)
+      } else {
+        return results
+      }
+    }
+  }
+
+  // Return the pending tasks list for a given host, or an empty list if
+  // there is no map entry for that host
+  def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+    pendingTasksForHost.getOrElse(host, ArrayBuffer())
+  }
+
+  // Dequeue a pending task from the given list and return its index.
+  // Return None if the list is empty.
+  // This method also cleans up any tasks in the list that have already
+  // been launched, since we want that to happen lazily.
+  def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+    while (!list.isEmpty) {
+      val index = list.last
+      list.trimEnd(1)
+      if (!launched(index) && !finished(index)) {
+        return Some(index)
+      }
+    }
+    return None
+  }
+
+  // Dequeue a pending task for a given node and return its index.
+  // If localOnly is set to false, allow non-local tasks as well.
+  def findTask(host: String, localOnly: Boolean): Option[Int] = {
+    val localTask = findTaskFromList(getPendingTasksForHost(host))
+    if (localTask != None) {
+      return localTask
+    }
+    val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+    if (noPrefTask != None) {
+      return noPrefTask
+    }
+    if (!localOnly) {
+      return findTaskFromList(allPendingTasks) // Look for non-local task
+    } else {
+      return None
+    }
+  }
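All of the scheduling knobs above come from Java system properties, so the delay-scheduling window and per-task resource requests can be tuned per job without recompiling; the values below are illustrative, not defaults:

    System.setProperty("spark.locality.wait", "5000") // wait up to 5 s for a preferred host
    System.setProperty("spark.task.cpus", "2")        // CPUs requested per task
    System.setProperty("spark.task.mem", "1024")      // MB of memory requested per task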
+  // Does a host count as a preferred location for a task? This is true if
+  // either the task has preferred locations and this host is one, or it has
+  // no preferred locations (in which case we still count the launch as preferred).
+  def isPreferredLocation(task: Task[T], host: String): Boolean = {
+    val locs = task.preferredLocations
+    return (locs.contains(host) || locs.isEmpty)
+  }
+
+  // Respond to an offer of a single slave from the scheduler by finding a task
+  def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
+  : Option[TaskDescription] = {
+    if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK &&
+        availableMem >= MEM_PER_TASK) {
+      val time = System.currentTimeMillis
+      val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+      val host = offer.getHost
+      findTask(host, localOnly) match {
+        case Some(index) => {
+          // Found a task; do some bookkeeping and return a Mesos task for it
+          val task = tasks(index)
+          val taskId = sched.newTaskId()
+          // Figure out whether this should count as a preferred launch
+          val preferred = isPreferredLocation(task, host)
+          val prefStr = if (preferred) "preferred" else "non-preferred"
+          val message =
+            "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
+              jobId, index, taskId, offer.getSlaveId, host, prefStr)
+          logInfo(message)
+          // Do various bookkeeping
+          tidToIndex(taskId) = index
+          task.markStarted(offer)
+          launched(index) = true
+          tasksLaunched += 1
+          if (preferred)
+            lastPreferredLaunchTime = time
+          // Create and return the Mesos task object
+          val params = new JHashMap[String, String]
+          params.put("cpus", CPUS_PER_TASK.toString)
+          params.put("mem", MEM_PER_TASK.toString)
+          val serializedTask = Utils.serialize(task)
+          logDebug("Serialized size: " + serializedTask.size)
+          val taskName = "task %d:%d".format(jobId, index)
+          return Some(new TaskDescription(
+            taskId, offer.getSlaveId, taskName, params, serializedTask))
+        }
+        case _ =>
+      }
+    }
+    return None
+  }
+
+  def statusUpdate(status: TaskStatus) {
+    status.getState match {
+      case TaskState.TASK_FINISHED =>
+        taskFinished(status)
+      case TaskState.TASK_LOST =>
+        taskLost(status)
+      case TaskState.TASK_FAILED =>
+        taskLost(status)
+      case TaskState.TASK_KILLED =>
+        taskLost(status)
+      case _ =>
+    }
+  }
+
+  def taskFinished(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      tasksFinished += 1
+      logInfo("Finished TID %d (progress: %d/%d)".format(
+        tid, tasksFinished, numTasks))
+      // Deserialize task result
+      val result = Utils.deserialize[TaskResult[T]](status.getData)
+      results(index) = result.value
+      // Update accumulators
+      Accumulators.add(callingThread, result.accumUpdates)
+      // Mark finished and stop if we've finished all the tasks
+      finished(index) = true
+      if (tasksFinished == numTasks)
+        setAllFinished()
+    } else {
+      logInfo("Ignoring task-finished event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def taskLost(status: TaskStatus) {
+    val tid = status.getTaskId
+    val index = tidToIndex(tid)
+    if (!finished(index)) {
+      logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index))
+      launched(index) = false
+      tasksLaunched -= 1
+      // Re-enqueue the task as pending
+      addPendingTask(index)
+      // Mark it as failed
+      if (status.getState == TaskState.TASK_FAILED ||
+          status.getState == TaskState.TASK_LOST) {
+        numFailures(index) += 1
+        if (numFailures(index) > MAX_TASK_FAILURES) {
+          logError("Task %d:%d failed more than %d times; aborting job".format(
+            jobId, index, MAX_TASK_FAILURES))
+          abort("Task %d failed more than %d times".format(
+            index, MAX_TASK_FAILURES))
+        }
+      }
+    } else {
+      logInfo("Ignoring task-lost event for TID " + tid +
+        " because task " + index + " is already finished")
+    }
+  }
+
+  def error(code: Int, message: String) {
+    // Save the error message
+    abort("Mesos error: %s (error code: %d)".format(message, code))
+  }
+
+  def abort(message: String) {
+    joinLock.synchronized {
+      failed = true
+      causeOfFailure = message
+      // TODO: Kill running tasks if we were not terminated due to a Mesos error
+      // Indicate to any joining thread that we're done
+      setAllFinished()
+    }
+  }
+}
diff --git a/src/scala/spark/SizeEstimator.scala b/src/scala/spark/SizeEstimator.scala
new file mode 100644
index 0000000000..12dd19d704
--- /dev/null
+++ b/src/scala/spark/SizeEstimator.scala
@@ -0,0 +1,160 @@
+package spark
+
+import java.lang.reflect.Field
+import java.lang.reflect.Modifier
+import java.lang.reflect.{Array => JArray}
+import java.util.IdentityHashMap
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * Estimates the sizes of Java objects (number of bytes of memory they occupy),
+ * for use in memory-aware caches.
+ *
+ * Based on the following JavaWorld article:
+ * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
+ */
+object SizeEstimator {
+  private val OBJECT_SIZE = 8 // Minimum size of a java.lang.Object
+  private val POINTER_SIZE = 4 // Size of an object reference
+
+  // Sizes of primitive types
+  private val BYTE_SIZE = 1
+  private val BOOLEAN_SIZE = 1
+  private val CHAR_SIZE = 2
+  private val SHORT_SIZE = 2
+  private val INT_SIZE = 4
+  private val LONG_SIZE = 8
+  private val FLOAT_SIZE = 4
+  private val DOUBLE_SIZE = 8
+
+  // A cache of ClassInfo objects for each class
+  private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo]
+  classInfos.put(classOf[Object], new ClassInfo(OBJECT_SIZE, Nil))
+
+  /**
+   * The state of an ongoing size estimation. Contains a stack of objects
+   * to visit as well as an IdentityHashMap of visited objects, and provides
+   * utility methods for enqueueing new objects to visit.
+   */
+  private class SearchState {
+    val visited = new IdentityHashMap[AnyRef, AnyRef]
+    val stack = new ArrayBuffer[AnyRef]
+    var size = 0L
+
+    def enqueue(obj: AnyRef) {
+      if (obj != null && !visited.containsKey(obj)) {
+        visited.put(obj, null)
+        stack += obj
+      }
+    }
+
+    def isFinished(): Boolean = stack.isEmpty
+
+    def dequeue(): AnyRef = {
+      val elem = stack.last
+      stack.trimEnd(1)
+      return elem
+    }
+  }
+
+  /**
+   * Cached information about each class. We remember two things: the
+   * "shell size" of the class (size of all non-static fields plus the
+   * java.lang.Object size), and any fields that are pointers to objects.
+   */
+  private class ClassInfo(
+    val shellSize: Long,
+    val pointerFields: List[Field]) {}
+
+  def estimate(obj: AnyRef): Long = {
+    val state = new SearchState
+    state.enqueue(obj)
+    while (!state.isFinished) {
+      visitSingleObject(state.dequeue(), state)
+    }
+    return state.size
+  }
+
+  private def visitSingleObject(obj: AnyRef, state: SearchState) {
+    val cls = obj.getClass
+    if (cls.isArray) {
+      visitArray(obj, cls, state)
+    } else {
+      val classInfo = getClassInfo(cls)
+      state.size += classInfo.shellSize
+      for (field <- classInfo.pointerFields) {
+        state.enqueue(field.get(obj))
+      }
+    }
+  }
+
+  private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
+    val length = JArray.getLength(array)
+    val elementClass = cls.getComponentType
+    if (elementClass.isPrimitive) {
+      state.size += length * primitiveSize(elementClass)
+    } else {
+      state.size += length * POINTER_SIZE
+      for (i <- 0 until length) {
+        state.enqueue(JArray.get(array, i))
+      }
+    }
+  }
+
+  private def primitiveSize(cls: Class[_]): Long = {
+    if (cls == classOf[Byte])
+      BYTE_SIZE
+    else if (cls == classOf[Boolean])
+      BOOLEAN_SIZE
+    else if (cls == classOf[Char])
+      CHAR_SIZE
+    else if (cls == classOf[Short])
+      SHORT_SIZE
+    else if (cls == classOf[Int])
+      INT_SIZE
+    else if (cls == classOf[Long])
+      LONG_SIZE
+    else if (cls == classOf[Float])
+      FLOAT_SIZE
+    else if (cls == classOf[Double])
+      DOUBLE_SIZE
+    else throw new IllegalArgumentException(
+      "Non-primitive class " + cls + " passed to primitiveSize()")
+  }
+
+  /**
+   * Get or compute the ClassInfo for a given class.
+   */
+  private def getClassInfo(cls: Class[_]): ClassInfo = {
+    // Check whether we've already cached a ClassInfo for this class
+    val info = classInfos.get(cls)
+    if (info != null) {
+      return info
+    }
+
+    val parent = getClassInfo(cls.getSuperclass)
+    var shellSize = parent.shellSize
+    var pointerFields = parent.pointerFields
+
+    for (field <- cls.getDeclaredFields) {
+      if (!Modifier.isStatic(field.getModifiers)) {
+        val fieldClass = field.getType
+        if (fieldClass.isPrimitive) {
+          shellSize += primitiveSize(fieldClass)
+        } else {
+          field.setAccessible(true) // Enable future get()'s on this field
+          shellSize += POINTER_SIZE
+          pointerFields = field :: pointerFields
+        }
+      }
+    }
+
+    // Create and cache a new ClassInfo
+    val newInfo = new ClassInfo(shellSize, pointerFields)
+    classInfos.put(cls, newInfo)
+    return newInfo
+  }
+}
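SizeEstimator walks the object graph from a root, visiting each distinct object once (via the IdentityHashMap), so shared references are not double-counted. A rough sketch of what it reports under the constants above (4-byte pointers; array headers are not modeled):

    SizeEstimator.estimate(new Array[Int](1000))  // about 1000 * 4 bytes
    val s = "spark"
    SizeEstimator.estimate(Array(s, s))           // the shared String is counted once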
diff --git a/src/scala/spark/SoftReferenceCache.scala b/src/scala/spark/SoftReferenceCache.scala
new file mode 100644
index 0000000000..e84aa57efa
--- /dev/null
+++ b/src/scala/spark/SoftReferenceCache.scala
@@ -0,0 +1,13 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses soft references.
+ */
+class SoftReferenceCache extends Cache {
+  val map = new MapMaker().softValues().makeMap[Any, Any]()
+
+  override def get(key: Any): Any = map.get(key)
+  override def put(key: Any, value: Any) = map.put(key, value)
+}
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index ef328d821a..8b8e408266 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -1,26 +1,55 @@
 package spark
 
 import java.io._
-import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
 import scala.actors.Actor._
 
-class SparkContext(master: String, frameworkName: String) extends Logging {
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+
+
+class SparkContext(
+  master: String,
+  frameworkName: String,
+  val sparkHome: String = null,
+  val jars: Seq[String] = Nil)
+extends Logging {
+  private var scheduler: Scheduler = {
+    // Regular expression used for local[N] master format
+    val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+    master match {
+      case "local" =>
+        new LocalScheduler(1)
+      case LOCAL_N_REGEX(threads) =>
+        new LocalScheduler(threads.toInt)
+      case _ =>
+        System.loadLibrary("mesos")
+        new MesosScheduler(this, master, frameworkName)
+    }
+  }
+
+  private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+
+  // Start the scheduler, the cache and the broadcast system
+  scheduler.start()
+  Cache.initialize()
   Broadcast.initialize(true)
 
-  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+  // Methods for creating RDDs
+
+  def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
     new ParallelArray[T](this, seq, numSlices)
 
-  def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
-    parallelize(seq, scheduler.numCores)
+  def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
+    parallelize(seq, numCores)
 
-  def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
-    new Accumulator(initialValue, param)
+  def textFile(path: String): RDD[String] =
+    new HadoopTextFile(this, path)
 
   // TODO: Keep around a weak hash map of values to Cached versions?
-  // def broadcast[T](value: T) = new DfsBroadcast(value, local)
-  def broadcast[T](value: T) = new ChainedBroadcast(value, local)
+  // def broadcast[T](value: T) = new DfsBroadcast(value, isLocal)
+  def broadcast[T](value: T) = new ChainedBroadcast(value, isLocal)
 
   // def broadcast[T](value: T) = {
   //   val broadcastClass = System.getProperty("spark.broadcast.Class",
@@ -35,39 +64,88 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
   //     instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
   //   }
 
+  /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+  def hadoopFile[K, V](path: String,
+                       inputFormatClass: Class[_ <: InputFormat[K, V]],
+                       keyClass: Class[K],
+                       valueClass: Class[V])
+  : RDD[(K, V)] = {
+    new HadoopFile(this, path, inputFormatClass, keyClass, valueClass)
+  }
 
-  def textFile(path: String) = new HdfsTextFile(this, path)
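Together with textFile, the new hadoopFile entry point can read any org.apache.hadoop.mapred.InputFormat. A usage sketch (the HDFS path is a placeholder; the Writable classes ship with the bundled hadoop-0.20.0):

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat

    val lines: RDD[String] = sc.textFile("hdfs://namenode:9000/data.txt")
    // The same file via the explicit InputFormat API:
    val records: RDD[(LongWritable, Text)] = sc.hadoopFile(
      "hdfs://namenode:9000/data.txt",
      classOf[TextInputFormat], classOf[LongWritable], classOf[Text])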
+  /**
+   * Smarter version of hadoopFile() that uses class manifests to figure out
+   * the classes of keys, values and the InputFormat so that users don't need
+   * to pass them directly.
+   */
+  def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+  : RDD[(K, V)] = {
+    hadoopFile(path,
+      fm.erasure.asInstanceOf[Class[F]],
+      km.erasure.asInstanceOf[Class[K]],
+      vm.erasure.asInstanceOf[Class[V]])
+  }
 
-  val LOCAL_REGEX = """local\[([0-9]+)\]""".r
+  /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+  def sequenceFile[K, V](path: String,
+                         keyClass: Class[K],
+                         valueClass: Class[V]): RDD[(K, V)] = {
+    val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
+    hadoopFile(path, inputFormatClass, keyClass, valueClass)
+  }
 
-  private[spark] var scheduler: Scheduler = master match {
-    case "local" => new LocalScheduler(1)
-    case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
-    case _ => { System.loadLibrary("mesos");
-      new MesosScheduler(master, frameworkName, createExecArg()) }
+  /**
+   * Smarter version of sequenceFile() that obtains the key and value classes
+   * from ClassManifests instead of requiring the user to pass them directly.
+   */
+  def sequenceFile[K, V](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V]): RDD[(K, V)] = {
+    sequenceFile(path,
+      km.erasure.asInstanceOf[Class[K]],
+      vm.erasure.asInstanceOf[Class[V]])
   }
 
-  private val local = scheduler.isInstanceOf[LocalScheduler]
+  /** Build the union of a list of RDDs. */
+  def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+    new UnionRDD(this, rdds)
 
-  scheduler.start()
+  // Methods for creating shared variables
 
-  private def createExecArg(): Array[Byte] = {
-    // Our executor arg is an array containing all the spark.* system properties
-    val props = new ArrayBuffer[(String, String)]
-    val iter = System.getProperties.entrySet.iterator
-    while (iter.hasNext) {
-      val entry = iter.next
-      val (key, value) = (entry.getKey.toString, entry.getValue.toString)
-      if (key.startsWith("spark."))
-        props += key -> value
-    }
-    return Utils.serialize(props.toArray)
+  def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
+    new Accumulator(initialValue, param)
+
+  // Stop the SparkContext
+  def stop() {
+    scheduler.stop()
+    scheduler = null
+  }
+
+  // Wait for the scheduler to be registered
+  def waitForRegister() {
+    scheduler.waitForRegister()
+  }
+
+  // Get Spark's home location from either a value set through the constructor,
+  // or the spark.home Java property, or the SPARK_HOME environment variable
+  // (in that order of preference). If none of these is set, return None.
+  def getSparkHome(): Option[String] = {
+    if (sparkHome != null)
+      Some(sparkHome)
+    else if (System.getProperty("spark.home") != null)
+      Some(System.getProperty("spark.home"))
+    else if (System.getenv("SPARK_HOME") != null)
+      Some(System.getenv("SPARK_HOME"))
+    else
+      None
   }
 
+  // Submit an array of tasks (passed as functions) to the scheduler
   def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = {
     runTaskObjects(tasks.map(f => new FunctionTask(f)))
   }
 
+  // Run an array of spark.Task objects
   private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]])
   : Array[T] = {
     logInfo("Running " + tasks.length + " tasks in parallel")
@@ -77,23 +155,22 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
     return result
   }
 
-  def stop() {
-    scheduler.stop()
-    scheduler = null
-  }
-
-  def waitForRegister() {
-    scheduler.waitForRegister()
-  }
-
-  // Clean a closure to make it ready to serialized and send to tasks
+  // Clean a closure to make it ready to be serialized and sent to tasks
   // (removes unreferenced variables in $outer's, updates REPL variables)
   private[spark] def clean[F <: AnyRef](f: F): F = {
     ClosureCleaner.clean(f)
     return f
   }
+
+  // Get the number of cores available to run tasks (as reported by Scheduler)
+  def numCores = scheduler.numCores
 }
 
+
+/**
+ * The SparkContext object contains a number of implicit conversions and
+ * parameters for use with various Spark features.
+ */
 object SparkContext {
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
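With the new constructor, a context can carry a Spark home and a list of jars to ship to executors, and local[N] master strings select an N-thread LocalScheduler instead of Mesos. A sketch (the path and jar name are placeholders):

    // Two local worker threads; sparkHome and jars only matter on a real cluster
    val sc = new SparkContext("local[2]", "demo", "/path/to/spark", Seq("app.jar"))
    val n = sc.numCores   // as reported by the scheduler
    sc.stop()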
diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala
index 7257bf7b0c..6f9be1a94f 100644
--- a/src/scala/spark/SparkException.scala
+++ b/src/scala/spark/SparkException.scala
@@ -1,7 +1,3 @@
 package spark
 
-class SparkException(message: String) extends Exception(message) {
-  def this(message: String, errorCode: Int) {
-    this("%s (error code: %d)".format(message, errorCode))
-  }
-}
+class SparkException(message: String) extends Exception(message) {}
diff --git a/src/scala/spark/Split.scala b/src/scala/spark/Split.scala
index 0f7a21354d..116cd16370 100644
--- a/src/scala/spark/Split.scala
+++ b/src/scala/spark/Split.scala
@@ -3,7 +3,7 @@ package spark
 /**
  * A partition of an RDD.
  */
-trait Split {
+@serializable trait Split {
   /**
    * Get a unique ID for this split which can be used, for example, to
    * set up caches based on it. The ID should stay the same if we serialize
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
index 27d73aefbd..e333dd9c91 100644
--- a/src/scala/spark/Utils.scala
+++ b/src/scala/spark/Utils.scala
@@ -1,12 +1,18 @@
 package spark
 
 import java.io._
+import java.net.InetAddress
+import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
+/**
+ * Various utility methods used by Spark.
+ */
 object Utils {
   def serialize[T](o: T): Array[Byte] = {
-    val bos = new ByteArrayOutputStream
+    val bos = new ByteArrayOutputStream()
     val oos = new ObjectOutputStream(bos)
     oos.writeObject(o)
     oos.close
@@ -50,4 +56,72 @@ object Utils {
     }
     return buf
   }
+
+  // Create a temporary directory inside the given parent directory
+  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File =
+  {
+    var attempts = 0
+    val maxAttempts = 10
+    var dir: File = null
+    while (dir == null) {
+      attempts += 1
+      if (attempts > maxAttempts) {
+        throw new IOException("Failed to create a temp directory " +
+          "after " + maxAttempts + " attempts!")
+      }
+      try {
+        dir = new File(root, "spark-" + UUID.randomUUID.toString)
+        if (dir.exists() || !dir.mkdirs()) {
+          dir = null
+        }
+      } catch { case e: IOException => ; }
+    }
+    return dir
+  }
+
+  // Copy all data from an InputStream to an OutputStream
+  def copyStream(in: InputStream,
+                 out: OutputStream,
+                 closeStreams: Boolean = false)
+  {
+    val buf = new Array[Byte](8192)
+    var n = 0
+    while (n != -1) {
+      n = in.read(buf)
+      if (n != -1) {
+        out.write(buf, 0, n)
+      }
+    }
+    if (closeStreams) {
+      in.close()
+      out.close()
+    }
+  }
+
+  // Shuffle the elements of a collection into a random order, returning the
+  // result in a new collection. Unlike scala.util.Random.shuffle, this method
+  // uses a local random number generator, avoiding inter-thread contention.
+  def shuffle[T](seq: TraversableOnce[T]): Seq[T] = {
+    val buf = new ArrayBuffer[T]()
+    buf ++= seq
+    val rand = new Random()
+    for (i <- (buf.size - 1) to 1 by -1) {
+      val j = rand.nextInt(i + 1) // Pick j uniformly from [0, i] (Fisher-Yates)
+      val tmp = buf(j)
+      buf(j) = buf(i)
+      buf(i) = tmp
+    }
+    buf
+  }
+
+  /**
+   * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
+   */
+  def localIpAddress(): String = {
+    // Get local IP as an array of four bytes
+    val bytes = InetAddress.getLocalHost().getAddress()
+    // Convert the bytes to ints (keeping in mind that they may be negative)
+    // and join them into a string
+    return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
+  }
 }
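The Utils additions above are shared by the scheduler and the REPL (the REPL's temp-directory logic is replaced by createTempDir in a later diff in this commit). A quick sketch of the helpers; Utils.deserialize is the existing counterpart to serialize used in SimpleJob:

    val bytes = Utils.serialize(List(1, 2, 3))
    val back = Utils.deserialize[List[Int]](bytes)
    val dir = Utils.createTempDir()       // e.g. <java.io.tmpdir>/spark-<uuid>
    val mixed = Utils.shuffle(1 to 100)   // uniform shuffle with a local RNG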
diff --git a/src/scala/spark/WeakReferenceCache.scala b/src/scala/spark/WeakReferenceCache.scala
new file mode 100644
index 0000000000..ddca065454
--- /dev/null
+++ b/src/scala/spark/WeakReferenceCache.scala
@@ -0,0 +1,14 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses weak references.
+ */
+class WeakReferenceCache extends Cache {
+  val map = new MapMaker().weakValues().makeMap[Any, Any]()
+
+  override def get(key: Any): Any = map.get(key)
+  override def put(key: Any, value: Any) = map.put(key, value)
+}
+
diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala
index ae2e7e8a68..10ea346658 100644
--- a/src/scala/spark/repl/SparkInterpreter.scala
+++ b/src/scala/spark/repl/SparkInterpreter.scala
@@ -36,6 +36,9 @@
 import scala.tools.nsc.{ InterpreterResults => IR }
 import interpreter._
 import SparkInterpreter._
 
+import spark.HttpServer
+import spark.Utils
+
 /** <p>
  *    An interpreter for Scala code.
  *  </p>
@@ -92,27 +95,12 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) {
   /** Local directory to save .class files to */
   val outputDir = {
-    val rootDir = new File(System.getProperty("spark.repl.classdir",
-      System.getProperty("java.io.tmpdir")))
-    var attempts = 0
-    val maxAttempts = 10
-    var dir: File = null
-    while (dir == null) {
-      attempts += 1
-      if (attempts > maxAttempts) {
-        throw new IOException("Failed to create a temp directory " +
-          "after " + maxAttempts + " attempts!")
-      }
-      try {
-        dir = new File(rootDir, "spark-" + UUID.randomUUID.toString)
-        if (dir.exists() || !dir.mkdirs())
-          dir = null
-      } catch { case e: IOException => ; }
-    }
-    if (SPARK_DEBUG_REPL) {
-      println("Output directory: " + dir)
-    }
-    dir
+    val tmp = System.getProperty("java.io.tmpdir")
+    val rootDir = System.getProperty("spark.repl.classdir", tmp)
+    Utils.createTempDir(rootDir)
+  }
+
+  if (SPARK_DEBUG_REPL) {
+    println("Output directory: " + outputDir)
   }
 
   /** Scala compiler virtual directory for outputDir */
@@ -120,14 +108,14 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) {
   val virtualDirectory = new PlainFile(outputDir)
 
   /** Jetty server that will serve our classes to worker nodes */
-  val classServer = new ClassServer(outputDir)
+  val classServer = new HttpServer(outputDir)
 
   // Start the classServer and store its URI in a spark system property
   // (which will be passed to executors so that they can connect to it)
   classServer.start()
   System.setProperty("spark.repl.class.uri", classServer.uri)
   if (SPARK_DEBUG_REPL) {
-    println("ClassServer started, URI = " + classServer.uri)
+    println("Class server started, URI = " + classServer.uri)
   }
 
   /** reporter */
diff --git a/src/test/spark/ShuffleSuite.scala b/src/test/spark/ShuffleSuite.scala
new file mode 100644
index 0000000000..a5773614e8
--- /dev/null
+++ b/src/test/spark/ShuffleSuite.scala
@@ -0,0 +1,130 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+class ShuffleSuite extends FunSuite {
+  test("groupByKey") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+    val groups = pairs.groupByKey().collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("groupByKey with duplicates") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val groups = pairs.groupByKey().collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("groupByKey with negative key hash codes") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+    val groups = pairs.groupByKey().collect()
+    assert(groups.size === 2)
+    val valuesForMinus1 = groups.find(_._1 == -1).get._2
+    assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("groupByKey with many output partitions") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+    val groups = pairs.groupByKey(10).collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("reduceByKey") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_).collect()
+    assert(sums.toSet === Set((1, 7), (2, 1)))
+  }
+
+  test("reduceByKey with collectAsMap") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_).collectAsMap()
+    assert(sums.size === 2)
+    assert(sums(1) === 7)
+    assert(sums(2) === 1)
+  }
+
+  test("reduceByKey with many output partitions") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.reduceByKey(_+_, 10).collect()
+    assert(sums.toSet === Set((1, 7), (2, 1)))
+  }
+
+  test("join") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 4)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (2, 'x')),
+      (2, (1, 'y')),
+      (2, (1, 'z'))
+    ))
+  }
+
+  test("join all-to-all") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 6)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (1, 'y')),
+      (1, (2, 'x')),
+      (1, (2, 'y')),
+      (1, (3, 'x')),
+      (1, (3, 'y'))
+    ))
+  }
+
+  test("join with no matches") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+    val joined = rdd1.join(rdd2).collect()
+    assert(joined.size === 0)
+  }
+
+  test("join with many output partitions") {
+    val sc = new SparkContext("local", "test")
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+    val joined = rdd1.join(rdd2, 10).collect()
+    assert(joined.size === 4)
+    assert(joined.toSet === Set(
+      (1, (1, 'x')),
+      (1, (2, 'x')),
+      (2, (1, 'y')),
+      (2, (1, 'z'))
+    ))
+  }
+}
diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala
index dcf71182ec..8b38cde85f 100644
--- a/src/test/spark/repl/ReplSuite.scala
+++ b/src/test/spark/repl/ReplSuite.scala
@@ -39,9 +39,9 @@ class ReplSuite extends FunSuite {
   test ("external vars") {
     val output = runInterpreter("local", """
       var v = 7
-      sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
       v = 10
-      sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
       """)
     assertDoesNotContain("error:", output)
    assertDoesNotContain("Exception", output)
@@ -54,7 +54,7 @@ class ReplSuite extends FunSuite {
       class C {
         def foo = 5
       }
-      sc.parallelize(1 to 10).map(x => (new C).foo).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -64,7 +64,7 @@ class ReplSuite extends FunSuite {
   test ("external functions") {
     val output = runInterpreter("local", """
       def double(x: Int) = x + x
-      sc.parallelize(1 to 10).map(x => double(x)).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -75,9 +75,9 @@ class ReplSuite extends FunSuite {
     val output = runInterpreter("local", """
       var v = 7
      def getV() = v
-      sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
       v = 10
-      sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+      sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -92,9 +92,9 @@ class ReplSuite extends FunSuite {
     val output = runInterpreter("local", """
       var array = new Array[Int](5)
       val broadcastArray = sc.broadcast(array)
-      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
+      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
       array(0) = 5
-      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
+      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
       """)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -102,24 +102,26 @@ class ReplSuite extends FunSuite {
     assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
   }
 
-  test ("running on Mesos") {
-    val output = runInterpreter("localquiet", """
-      var v = 7
-      def getV() = v
-      sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
-      v = 10
-      sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
-      var array = new Array[Int](5)
-      val broadcastArray = sc.broadcast(array)
-      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
-      array(0) = 5
-      sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
-      """)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-    assertContains("res0: Int = 70", output)
-    assertContains("res1: Int = 100", output)
-    assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
-    assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+  if (System.getenv("MESOS_HOME") != null) {
+    test ("running on Mesos") {
+      val output = runInterpreter("localquiet", """
+        var v = 7
+        def getV() = v
+        sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+        v = 10
+        sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+        var array = new Array[Int](5)
+        val broadcastArray = sc.broadcast(array)
+        sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+        array(0) = 5
+        sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+        """)
+      assertDoesNotContain("error:", output)
+      assertDoesNotContain("Exception", output)
+      assertContains("res0: Int = 70", output)
+      assertContains("res1: Int = 100", output)
+      assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+      assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+    }
   }
 }
diff --git a/third_party/guava-r06/guava-r06.jar b/third_party/guava-r06/guava-r06.jar
Binary files differ
deleted file mode 100644
index 8ff3a81748..0000000000
--- a/third_party/guava-r06/guava-r06.jar
+++ /dev/null
diff --git a/third_party/guava-r06/COPYING b/third_party/guava-r07/COPYING
index d645695673..d645695673 100644
--- a/third_party/guava-r06/COPYING
+++ b/third_party/guava-r07/COPYING
diff --git a/third_party/guava-r06/README b/third_party/guava-r07/README
index a0e832dd54..a0e832dd54 100644
--- a/third_party/guava-r06/README
+++ b/third_party/guava-r07/README
diff --git a/third_party/guava-r07/guava-r07.jar b/third_party/guava-r07/guava-r07.jar
Binary files differ
new file mode 100644
index 0000000000..a6c9ce02df
--- /dev/null
+++ b/third_party/guava-r07/guava-r07.jar
diff --git a/third_party/mesos.jar b/third_party/mesos.jar
Binary files differ
index 1852cf8fd0..60d299c8af 100644
--- a/third_party/mesos.jar
+++ b/third_party/mesos.jar