author     Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>  2010-11-27 01:27:20 -0800
committer  Mosharaf Chowdhury <mosharaf@mosharaf-ubuntu.(none)>  2010-11-27 01:27:20 -0800
commit     19a6b194a9aceb3d7997f0acfc99530ac2792be4 (patch)
tree       7565c6f7c334f13b583d22d7fdfb569e3dd12671
parent     e4b8db45aef934929dbab443156375aebb1ea45e (diff)
parent     f8ea98d9894d72feb7e8cd3951a576b24b448397 (diff)
Merge branch 'master' into multi-tracker
Conflicts:
    Makefile
    run
    src/scala/spark/Broadcast.scala
    src/scala/spark/Executor.scala
    src/scala/spark/HdfsFile.scala
    src/scala/spark/MesosScheduler.scala
    src/scala/spark/RDD.scala
    src/scala/spark/SparkContext.scala
    src/scala/spark/Split.scala
    src/scala/spark/Utils.scala
    src/scala/spark/repl/SparkInterpreter.scala
    third_party/mesos.jar
-rw-r--r--  .gitignore                                      5
-rw-r--r--  Makefile                                       17
-rwxr-xr-x  alltests                                       12
-rw-r--r--  conf/java-opts.template                         0
-rw-r--r--  conf/log4j.properties.template                  8
-rwxr-xr-x  conf/spark-env.sh.template                     13
-rwxr-xr-x  run                                            18
-rw-r--r--  src/examples/BroadcastTest.scala               18
-rw-r--r--  src/examples/SparkPi.scala                      2
-rw-r--r--  src/scala/spark/BoundedMemoryCache.scala       69
-rw-r--r--  src/scala/spark/Cache.scala                    63
-rw-r--r--  src/scala/spark/DfsShuffle.scala              120
-rw-r--r--  src/scala/spark/Executor.scala                153
-rw-r--r--  src/scala/spark/HadoopFile.scala              118
-rw-r--r--  src/scala/spark/HdfsFile.scala                 80
-rw-r--r--  src/scala/spark/HttpServer.scala               67
-rw-r--r--  src/scala/spark/Job.scala                      18
-rw-r--r--  src/scala/spark/LocalFileShuffle.scala        171
-rw-r--r--  src/scala/spark/MesosScheduler.scala          394
-rw-r--r--  src/scala/spark/NumberedSplitRDD.scala         42
-rw-r--r--  src/scala/spark/RDD.scala                     158
-rw-r--r--  src/scala/spark/Shuffle.scala                  15
-rw-r--r--  src/scala/spark/SimpleJob.scala               272
-rw-r--r--  src/scala/spark/SizeEstimator.scala           160
-rw-r--r--  src/scala/spark/SoftReferenceCache.scala       13
-rw-r--r--  src/scala/spark/SparkContext.scala            153
-rw-r--r--  src/scala/spark/SparkException.scala            6
-rw-r--r--  src/scala/spark/Split.scala                     2
-rw-r--r--  src/scala/spark/Utils.scala                    76
-rw-r--r--  src/scala/spark/WeakReferenceCache.scala       14
-rw-r--r--  src/scala/spark/repl/SparkInterpreter.scala    34
-rw-r--r--  src/test/spark/ShuffleSuite.scala             130
-rw-r--r--  src/test/spark/repl/ReplSuite.scala            56
-rw-r--r--  third_party/guava-r06/guava-r06.jar           bin 934385 -> 0 bytes
-rw-r--r--  third_party/guava-r07/COPYING (renamed from third_party/guava-r06/COPYING)   0
-rw-r--r--  third_party/guava-r07/README (renamed from third_party/guava-r06/README)     0
-rw-r--r--  third_party/guava-r07/guava-r07.jar           bin 0 -> 1075964 bytes
-rw-r--r--  third_party/mesos.jar                         bin 34562 -> 33618 bytes
38 files changed, 1957 insertions(+), 520 deletions(-)
diff --git a/.gitignore b/.gitignore
index 2d12458a44..5abdec5d50 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,8 @@
build
work
.DS_Store
+third_party/libmesos.so
+third_party/libmesos.dylib
+conf/java-opts
+conf/spark-env.sh
+conf/log4j.properties
diff --git a/Makefile b/Makefile
index c5d004fb10..15ab516d1f 100644
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,7 @@ SPACE = $(EMPTY) $(EMPTY)
JARS = third_party/mesos.jar
JARS += third_party/asm-3.2/lib/all/asm-all-3.2.jar
JARS += third_party/colt.jar
-JARS += third_party/guava-r06/guava-r06.jar
+JARS += third_party/guava-r07/guava-r07.jar
JARS += third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar
JARS += third_party/hadoop-0.20.0/lib/commons-logging-1.0.4.jar
JARS += third_party/scalatest-1.2/scalatest-1.2.jar
@@ -34,13 +34,15 @@ else
COMPILER = $(SCALA_HOME)/bin/$(COMPILER_NAME)
endif
-all: scala java
+CONF_FILES = conf/spark-env.sh conf/log4j.properties conf/java-opts
+
+all: scala java conf-files
build/classes:
mkdir -p build/classes
scala: build/classes java
- $(COMPILER) -unchecked -d build/classes -classpath build/classes:$(CLASSPATH) $(SCALA_SOURCES)
+ $(COMPILER) -d build/classes -classpath build/classes:$(CLASSPATH) $(SCALA_SOURCES)
java: $(JAVA_SOURCES) build/classes
javac -d build/classes $(JAVA_SOURCES)
@@ -50,6 +52,8 @@ native: java
jar: build/spark.jar build/spark-dep.jar
+dep-jar: build/spark-dep.jar
+
build/spark.jar: scala java
jar cf build/spark.jar -C build/classes spark
@@ -58,6 +62,11 @@ build/spark-dep.jar:
cd build/dep && for i in $(JARS); do jar xf ../../$$i; done
jar cf build/spark-dep.jar -C build/dep .
+conf-files: $(CONF_FILES)
+
+$(CONF_FILES): %: | %.template
+ cp $@.template $@
+
test: all
./alltests
@@ -67,4 +76,4 @@ clean:
$(MAKE) -C src/native clean
rm -rf build
-.phony: default all clean scala java native jar
+.phony: default all clean scala java native jar dep-jar conf-files
diff --git a/alltests b/alltests
index 3c9db301c4..cd11604855 100755
--- a/alltests
+++ b/alltests
@@ -1,3 +1,11 @@
#!/bin/bash
-FWDIR=`dirname $0`
-$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -o $@
+FWDIR="`dirname $0`"
+if [ "x$SPARK_MEM" == "x" ]; then
+ export SPARK_MEM=500m
+fi
+RESULTS_DIR="$FWDIR/build/test_results"
+if [ -d $RESULTS_DIR ]; then
+ rm -r $RESULTS_DIR
+fi
+mkdir -p $RESULTS_DIR
+$FWDIR/run org.scalatest.tools.Runner -p $FWDIR/build/classes -u $RESULTS_DIR -o $@
diff --git a/conf/java-opts.template b/conf/java-opts.template
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/conf/java-opts.template
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
new file mode 100644
index 0000000000..d72dbadc39
--- /dev/null
+++ b/conf/log4j.properties.template
@@ -0,0 +1,8 @@
+# Set everything to be logged to the console
+log4j.rootCategory=INFO, console
+log4j.appender.console=org.apache.log4j.ConsoleAppender
+log4j.appender.console.layout=org.apache.log4j.PatternLayout
+log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
new file mode 100755
index 0000000000..6852b23a34
--- /dev/null
+++ b/conf/spark-env.sh.template
@@ -0,0 +1,13 @@
+#!/usr/bin/env bash
+
+# Set Spark environment variables for your site in this file. Some useful
+# variables to set are:
+# - MESOS_HOME, to point to your Mesos installation
+# - SCALA_HOME, to point to your Scala installation
+# - SPARK_CLASSPATH, to add elements to Spark's classpath
+# - SPARK_JAVA_OPTS, to add JVM options
+# - SPARK_MEM, to change the amount of memory used per node (this should
+# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g).
+# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries.
+
+
diff --git a/run b/run
index f28b39af9b..d6f7d920c5 100755
--- a/run
+++ b/run
@@ -1,16 +1,22 @@
#!/bin/bash
# Figure out where the Scala framework is installed
-FWDIR=`dirname $0`
+FWDIR="$(cd `dirname $0`; pwd)"
+
+# Export this as SPARK_HOME
+export SPARK_HOME="$FWDIR"
# Load environment variables from conf/spark-env.sh, if it exists
if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
fi
+MESOS_CLASSPATH=""
+MESOS_LIBRARY_PATH=""
+
if [ "x$MESOS_HOME" != "x" ] ; then
- SPARK_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar:$SPARK_CLASSPATH"
- SPARK_LIBRARY_PATH="$MESOS_HOME/lib/java:$SPARK_LIBARY_PATH"
+ MESOS_CLASSPATH="$MESOS_HOME/lib/java/mesos.jar"
+ MESOS_LIBRARY_PATH="$MESOS_HOME/lib/java"
fi
if [ "x$SPARK_MEM" == "x" ] ; then
@@ -19,7 +25,7 @@ fi
# Set JAVA_OPTS to be able to load native libraries and to set heap size
JAVA_OPTS="$SPARK_JAVA_OPTS"
-JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native"
+JAVA_OPTS+=" -Djava.library.path=$SPARK_LIBRARY_PATH:$FWDIR/third_party:$FWDIR/src/native:$MESOS_LIBRARY_PATH"
JAVA_OPTS+=" -Xms$SPARK_MEM -Xmx$SPARK_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
if [ -e $FWDIR/conf/java-opts ] ; then
@@ -28,12 +34,12 @@ fi
export JAVA_OPTS
# Build up classpath
-CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes"
+CLASSPATH="$SPARK_CLASSPATH:$FWDIR/build/classes:$MESOS_CLASSPATH"
CLASSPATH+=:$FWDIR/conf
CLASSPATH+=:$FWDIR/third_party/mesos.jar
CLASSPATH+=:$FWDIR/third_party/asm-3.2/lib/all/asm-all-3.2.jar
CLASSPATH+=:$FWDIR/third_party/colt.jar
-CLASSPATH+=:$FWDIR/third_party/guava-r06/guava-r06.jar
+CLASSPATH+=:$FWDIR/third_party/guava-r07/guava-r07.jar
CLASSPATH+=:$FWDIR/third_party/hadoop-0.20.0/hadoop-0.20.0-core.jar
CLASSPATH+=:$FWDIR/third_party/scalatest-1.2/scalatest-1.2.jar
CLASSPATH+=:$FWDIR/third_party/scalacheck_2.8.0-1.7.jar
diff --git a/src/examples/BroadcastTest.scala b/src/examples/BroadcastTest.scala
index 7764013413..40c2be8f6d 100644
--- a/src/examples/BroadcastTest.scala
+++ b/src/examples/BroadcastTest.scala
@@ -10,15 +10,19 @@ object BroadcastTest {
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
- var arr = new Array[Int](num)
- for (i <- 0 until arr.length)
- arr(i) = i
+ var arr1 = new Array[Int](num)
+ for (i <- 0 until arr1.length)
+ arr1(i) = i
- val barr = spark.broadcast(arr)
+// var arr2 = new Array[Int](num * 2)
+// for (i <- 0 until arr2.length)
+// arr2(i) = i
+
+ val barr1 = spark.broadcast(arr1)
+// val barr2 = spark.broadcast(arr2)
spark.parallelize(1 to 10, slices).foreach {
- println("in task: barr = " + barr)
- i => println(barr.value.size)
+// i => println(barr1.value.size + barr2.value.size)
+ i => println(barr1.value.size)
}
}
}
-
diff --git a/src/examples/SparkPi.scala b/src/examples/SparkPi.scala
index 07311908ee..f055614125 100644
--- a/src/examples/SparkPi.scala
+++ b/src/examples/SparkPi.scala
@@ -5,7 +5,7 @@ import SparkContext._
object SparkPi {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: SparkLR <host> [<slices>]")
+ System.err.println("Usage: SparkPi <host> [<slices>]")
System.exit(1)
}
val spark = new SparkContext(args(0), "SparkPi")
diff --git a/src/scala/spark/BoundedMemoryCache.scala b/src/scala/spark/BoundedMemoryCache.scala
new file mode 100644
index 0000000000..19d9bebfe5
--- /dev/null
+++ b/src/scala/spark/BoundedMemoryCache.scala
@@ -0,0 +1,69 @@
+package spark
+
+import java.util.LinkedHashMap
+
+/**
+ * An implementation of Cache that estimates the sizes of its entries and
+ * attempts to limit its total memory usage to a fraction of the JVM heap.
+ * Objects' sizes are estimated using SizeEstimator, which has limitations;
+ * most notably, we will overestimate total memory used if some cache
+ * entries have pointers to a shared object. Nonetheless, this Cache should
+ * work well when most of the space is used by arrays of primitives or of
+ * simple classes.
+ */
+class BoundedMemoryCache extends Cache with Logging {
+ private val maxBytes: Long = getMaxBytes()
+ logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
+
+ private var currentBytes = 0L
+ private val map = new LinkedHashMap[Any, Entry](32, 0.75f, true)
+
+ // An entry in our map; stores a cached object and its size in bytes
+ class Entry(val value: Any, val size: Long) {}
+
+ override def get(key: Any): Any = {
+ synchronized {
+ val entry = map.get(key)
+ if (entry != null) entry.value else null
+ }
+ }
+
+ override def put(key: Any, value: Any) {
+ logInfo("Asked to add key " + key)
+ val startTime = System.currentTimeMillis
+ val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Estimated size for key %s is %d".format(key, size))
+ logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
+ synchronized {
+ ensureFreeSpace(size)
+ logInfo("Adding key " + key)
+ map.put(key, new Entry(value, size))
+ currentBytes += size
+ logInfo("Number of entries is now " + map.size)
+ }
+ }
+
+ private def getMaxBytes(): Long = {
+ val memoryFractionToUse = System.getProperty(
+ "spark.boundedMemoryCache.memoryFraction", "0.75").toDouble
+ (Runtime.getRuntime.totalMemory * memoryFractionToUse).toLong
+ }
+
+ /**
+ * Remove least recently used entries from the map until at least space
+ * bytes are free. Assumes that a lock is held on the BoundedMemoryCache.
+ */
+ private def ensureFreeSpace(space: Long) {
+ logInfo("ensureFreeSpace(%d) called with curBytes=%d, maxBytes=%d".format(
+ space, currentBytes, maxBytes))
+ val iter = map.entrySet.iterator
+ while (maxBytes - currentBytes < space && iter.hasNext) {
+ val mapEntry = iter.next()
+ logInfo("Dropping key %s of size %d to make space".format(
+ mapEntry.getKey, mapEntry.getValue.size))
+ currentBytes -= mapEntry.getValue.size
+ iter.remove()
+ }
+ }
+}
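A minimal usage sketch for the cache added above (illustrative only, not part of this commit): the memory-fraction property and the LRU eviction order come from the code, while the keys and array sizes are arbitrary.

    // Sketch: exercise BoundedMemoryCache's put/get and its LRU eviction.
    object BoundedMemoryCacheSketch {
      def main(args: Array[String]) {
        // Cap the cache at 10% of the JVM heap instead of the default 75%.
        System.setProperty("spark.boundedMemoryCache.memoryFraction", "0.1")
        val cache = new spark.BoundedMemoryCache
        cache.put("split-0", Array.fill(1000000)(1))  // roughly 4 MB of Ints
        cache.put("split-1", Array.fill(1000000)(2))  // may evict split-0 if over budget
        println(cache.get("split-0"))                 // null if split-0 was evicted first
        println(cache.get("split-1"))
      }
    }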
diff --git a/src/scala/spark/Cache.scala b/src/scala/spark/Cache.scala
new file mode 100644
index 0000000000..9887520758
--- /dev/null
+++ b/src/scala/spark/Cache.scala
@@ -0,0 +1,63 @@
+package spark
+
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+ * An interface for caches in Spark, to allow for multiple implementations.
+ * Caches are used to store both partitions of cached RDDs and broadcast
+ * variables on Spark executors.
+ *
+ * A single Cache instance gets created on each machine and is shared by all
+ * caches (i.e. both the RDD split cache and the broadcast variable cache),
+ * to enable global replacement policies. However, because these several
+ * independent modules all perform caching, it is important to give them
+ * separate key namespaces, so that an RDD and a broadcast variable (for
+ * example) do not use the same key. For this purpose, Cache has the
+ * notion of KeySpaces. Each client module must first ask for a KeySpace,
+ * and then call get() and put() on that space using its own keys.
+ * This abstract class handles the creation of key spaces, so that subclasses
+ * need only deal with keys that are unique across modules.
+ */
+abstract class Cache {
+ private val nextKeySpaceId = new AtomicLong(0)
+ private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
+
+ def newKeySpace() = new KeySpace(this, newKeySpaceId())
+
+ def get(key: Any): Any
+ def put(key: Any, value: Any): Unit
+}
+
+
+/**
+ * A key namespace in a Cache.
+ */
+class KeySpace(cache: Cache, id: Long) {
+ def get(key: Any): Any = cache.get((id, key))
+ def put(key: Any, value: Any): Unit = cache.put((id, key), value)
+}
+
+
+/**
+ * The Cache object maintains a global Cache instance, of the type specified
+ * by the spark.cache.class property.
+ */
+object Cache {
+ private var instance: Cache = null
+
+ def initialize() {
+ val cacheClass = System.getProperty("spark.cache.class",
+ "spark.SoftReferenceCache")
+ instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+ }
+
+ def getInstance(): Cache = {
+ if (instance == null) {
+ throw new SparkException("Cache.getInstance called before initialize")
+ }
+ instance
+ }
+
+ def newKeySpace(): KeySpace = getInstance().newKeySpace()
+}
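The KeySpace contract described above, shown as a short sketch (illustrative only, not part of this commit): two modules share one global Cache but cannot collide on keys because each goes through its own KeySpace.

    // Sketch: two client modules sharing one Cache through separate key spaces.
    object KeySpaceSketch {
      def main(args: Array[String]) {
        System.setProperty("spark.cache.class", "spark.BoundedMemoryCache")
        spark.Cache.initialize()                    // create the global Cache instance
        val rddSpace = spark.Cache.newKeySpace()    // e.g. the RDD split cache
        val bcastSpace = spark.Cache.newKeySpace()  // e.g. the broadcast variable cache
        rddSpace.put(42, "rdd split data")
        bcastSpace.put(42, "broadcast data")        // same raw key, different namespace
        println(rddSpace.get(42))                   // prints "rdd split data"
        println(bcastSpace.get(42))                 // prints "broadcast data"
      }
    }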
diff --git a/src/scala/spark/DfsShuffle.scala b/src/scala/spark/DfsShuffle.scala
new file mode 100644
index 0000000000..7a42bf2d06
--- /dev/null
+++ b/src/scala/spark/DfsShuffle.scala
@@ -0,0 +1,120 @@
+package spark
+
+import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
+import java.net.URI
+import java.util.UUID
+
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
+
+
+/**
+ * A simple implementation of shuffle using a distributed file system.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val dir = DfsShuffle.newTempDirectory()
+ logInfo("Intermediate data directory: " + dir)
+
+ val numberedSplitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = numberedSplitRdd.splits.size
+
+ // Run a parallel foreach to write the intermediate data files
+ numberedSplitRdd.foreach((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+ val fs = DfsShuffle.getFileSystem()
+ for (i <- 0 until numOutputSplits) {
+ val path = new Path(dir, "%d-to-%d".format(myIndex, i))
+ val out = new ObjectOutputStream(fs.create(path, true))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ }
+ })
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myIndex: Int) => {
+ val combiners = new HashMap[K, C]
+ val fs = DfsShuffle.getFileSystem()
+ for (i <- Utils.shuffle(0 until numInputSplits)) {
+ val path = new Path(dir, "%d-to-%d".format(i, myIndex))
+ val inputStream = new ObjectInputStream(fs.open(path))
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => {}
+ }
+ inputStream.close()
+ }
+ combiners
+ })
+ }
+}
+
+
+/**
+ * Companion object of DfsShuffle; responsible for initializing a Hadoop
+ * FileSystem object based on the spark.dfs property and generating names
+ * for temporary directories.
+ */
+object DfsShuffle {
+ private var initialized = false
+ private var fileSystem: FileSystem = null
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val dfs = System.getProperty("spark.dfs", "file:///")
+ val conf = new Configuration()
+ conf.setInt("io.file.buffer.size", bufferSize)
+ conf.setInt("dfs.replication", 1)
+ fileSystem = FileSystem.get(new URI(dfs), conf)
+ initialized = true
+ }
+ }
+
+ def getFileSystem(): FileSystem = {
+ initializeIfNeeded()
+ return fileSystem
+ }
+
+ def newTempDirectory(): String = {
+ val fs = getFileSystem()
+ val workDir = System.getProperty("spark.dfs.workdir", "/tmp")
+ val uuid = UUID.randomUUID()
+ val path = workDir + "/shuffle-" + uuid
+ fs.mkdirs(new Path(path))
+ return path
+ }
+}
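To make the three combiner functions concrete, here is a word-count style sketch (illustrative only, not part of this commit; the input RDD `pairs` and the split count of 4 are assumptions):

    // Sketch: sum up counts per word with DfsShuffle.compute.
    object WordCountShuffleSketch {
      def countByWord(pairs: spark.RDD[(String, Int)]): spark.RDD[(String, Int)] = {
        val shuffle = new spark.DfsShuffle[String, Int, Int]
        shuffle.compute(
          pairs,                           // input RDD of (key, value) pairs
          4,                               // numOutputSplits
          (v: Int) => v,                   // createCombiner: first value seen for a key
          (c: Int, v: Int) => c + v,       // mergeValue: fold another value into a combiner
          (c1: Int, c2: Int) => c1 + c2)   // mergeCombiners: merge map-side combiners
      }
    }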
diff --git a/src/scala/spark/Executor.scala b/src/scala/spark/Executor.scala
index be73aae541..b4d023b428 100644
--- a/src/scala/spark/Executor.scala
+++ b/src/scala/spark/Executor.scala
@@ -1,75 +1,116 @@
package spark
+import java.io.{File, FileOutputStream}
+import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent.{Executors, ExecutorService}
+import scala.collection.mutable.ArrayBuffer
+
import mesos.{ExecutorArgs, ExecutorDriver, MesosExecutorDriver}
import mesos.{TaskDescription, TaskState, TaskStatus}
/**
* The Mesos executor for Spark.
*/
-object Executor extends Logging {
- def main(args: Array[String]) {
- System.loadLibrary("mesos")
+class Executor extends mesos.Executor with Logging {
+ var classLoader: ClassLoader = null
+ var threadPool: ExecutorService = null
- // Create a new Executor implementation that will run our tasks
- val exec = new mesos.Executor() {
- var classLoader: ClassLoader = null
- var threadPool: ExecutorService = null
+ override def init(d: ExecutorDriver, args: ExecutorArgs) {
+ // Read spark.* system properties from executor arg
+ val props = Utils.deserialize[Array[(String, String)]](args.getData)
+ for ((key, value) <- props)
+ System.setProperty(key, value)
- override def init(d: ExecutorDriver, args: ExecutorArgs) {
- // Read spark.* system properties
- val props = Utils.deserialize[Array[(String, String)]](args.getData)
- for ((key, value) <- props)
- System.setProperty(key, value)
-
- // Initialize broadcast system (uses some properties read above)
- Broadcast.initialize(false)
-
- // If the REPL is in use, create a ClassLoader that will be able to
- // read new classes defined by the REPL as the user types code
- classLoader = this.getClass.getClassLoader
- val classUri = System.getProperty("spark.repl.class.uri")
- if (classUri != null) {
- logInfo("Using REPL class URI: " + classUri)
- classLoader = new repl.ExecutorClassLoader(classUri, classLoader)
- }
- Thread.currentThread.setContextClassLoader(classLoader)
-
- // Start worker thread pool (they will inherit our context ClassLoader)
- threadPool = Executors.newCachedThreadPool()
- }
-
- override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
- // Pull taskId and arg out of TaskDescription because it won't be a
- // valid pointer after this method call (TODO: fix this in C++/SWIG)
- val taskId = desc.getTaskId
- val arg = desc.getArg
- threadPool.execute(new Runnable() {
- def run() = {
- logInfo("Running task ID " + taskId)
- try {
- Accumulators.clear
- val task = Utils.deserialize[Task[Any]](arg, classLoader)
- val value = task.run
- val accumUpdates = Accumulators.values
- val result = new TaskResult(value, accumUpdates)
- d.sendStatusUpdate(new TaskStatus(
- taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
- logInfo("Finished task ID " + taskId)
- } catch {
- case e: Exception => {
- // TODO: Handle errors in tasks less dramatically
- logError("Exception in task ID " + taskId, e)
- System.exit(1)
- }
- }
+ // Initialize cache and broadcast system (uses some properties read above)
+ Cache.initialize()
+ Broadcast.initialize(false)
+
+ // Create our ClassLoader (using spark properties) and set it on this thread
+ classLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(classLoader)
+
+ // Start worker thread pool (they will inherit our context ClassLoader)
+ threadPool = Executors.newCachedThreadPool()
+ }
+
+ override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
+ // Pull taskId and arg out of TaskDescription because it won't be a
+ // valid pointer after this method call (TODO: fix this in C++/SWIG)
+ val taskId = desc.getTaskId
+ val arg = desc.getArg
+ threadPool.execute(new Runnable() {
+ def run() = {
+ logInfo("Running task ID " + taskId)
+ try {
+ Accumulators.clear
+ val task = Utils.deserialize[Task[Any]](arg, classLoader)
+ val value = task.run
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates)
+ d.sendStatusUpdate(new TaskStatus(
+ taskId, TaskState.TASK_FINISHED, Utils.serialize(result)))
+ logInfo("Finished task ID " + taskId)
+ } catch {
+ case e: Exception => {
+ // TODO: Handle errors in tasks less dramatically
+ logError("Exception in task ID " + taskId, e)
+ System.exit(1)
}
- })
+ }
}
+ })
+ }
+
+ // Create a ClassLoader for use in tasks, adding any JARs specified by the
+ // user or any classes created by the interpreter to the search path
+ private def createClassLoader(): ClassLoader = {
+ var loader = this.getClass.getClassLoader
+
+ // If any JAR URIs are given through spark.jar.uris, fetch them to the
+ // current directory and put them all on the classpath. We assume that
+ // each URL has a unique file name so that no local filenames will clash
+ // in this process. This is guaranteed by MesosScheduler.
+ val uris = System.getProperty("spark.jar.uris", "")
+ val localFiles = ArrayBuffer[String]()
+ for (uri <- uris.split(",").filter(_.size > 0)) {
+ val url = new URL(uri)
+ val filename = url.getPath.split("/").last
+ downloadFile(url, filename)
+ localFiles += filename
+ }
+ if (localFiles.size > 0) {
+ val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
+ loader = new URLClassLoader(urls, loader)
}
- // Start it running and connect it to the slave
+ // If the REPL is in use, add another ClassLoader that will read
+ // new classes defined by the REPL as the user types code
+ val classUri = System.getProperty("spark.repl.class.uri")
+ if (classUri != null) {
+ logInfo("Using REPL class URI: " + classUri)
+ loader = new repl.ExecutorClassLoader(classUri, loader)
+ }
+
+ return loader
+ }
+
+ // Download a file from a given URL to the local filesystem
+ private def downloadFile(url: URL, localPath: String) {
+ val in = url.openStream()
+ val out = new FileOutputStream(localPath)
+ Utils.copyStream(in, out, true)
+ }
+}
+
+/**
+ * Executor entry point.
+ */
+object Executor extends Logging {
+ def main(args: Array[String]) {
+ System.loadLibrary("mesos")
+ // Create a new Executor and start it running
+ val exec = new Executor
new MesosExecutorDriver(exec).run()
}
}
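The createClassLoader code above expects spark.jar.uris to be a comma-separated list of URLs with unique file names; a small sketch of that format (illustrative only; the server address and jar names are made up):

    // Sketch: the property format the executor consumes. MesosScheduler normally
    // sets this; the executor downloads each file and layers a URLClassLoader on top.
    object JarUrisSketch {
      def main(args: Array[String]) {
        System.setProperty("spark.jar.uris",
          "http://10.0.0.1:33333/0_app.jar,http://10.0.0.1:33333/1_deps.jar")
      }
    }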
diff --git a/src/scala/spark/HadoopFile.scala b/src/scala/spark/HadoopFile.scala
new file mode 100644
index 0000000000..a63c9d8a94
--- /dev/null
+++ b/src/scala/spark/HadoopFile.scala
@@ -0,0 +1,118 @@
+package spark
+
+import mesos.SlaveOffer
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.InputSplit
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapred.RecordReader
+import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.util.ReflectionUtils
+
+/** A Spark split class that wraps around a Hadoop InputSplit */
+@serializable class HadoopSplit(@transient s: InputSplit)
+extends Split {
+ val inputSplit = new SerializableWritable[InputSplit](s)
+
+ // Hadoop gives each split a unique toString value, so use this as our ID
+ override def getId() = "HadoopSplit(" + inputSplit.toString + ")"
+}
+
+
+/**
+ * An RDD that reads a Hadoop file (from HDFS, S3, the local filesystem, etc)
+ * and represents it as a set of key-value pairs using a given InputFormat.
+ */
+class HadoopFile[K, V](
+ sc: SparkContext,
+ path: String,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V])
+extends RDD[(K, V)](sc) {
+ @transient val splits_ : Array[Split] = ConfigureLock.synchronized {
+ val conf = new JobConf()
+ FileInputFormat.setInputPaths(conf, path)
+ val inputFormat = createInputFormat(conf)
+ val inputSplits = inputFormat.getSplits(conf, sc.numCores)
+ inputSplits.map(x => new HadoopSplit(x): Split).toArray
+ }
+
+ def createInputFormat(conf: JobConf): InputFormat[K, V] = {
+ ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+ .asInstanceOf[InputFormat[K, V]]
+ }
+
+ override def splits = splits_
+
+ override def iterator(theSplit: Split) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopSplit]
+ var reader: RecordReader[K, V] = null
+
+ ConfigureLock.synchronized {
+ val conf = new JobConf()
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ val fmt = createInputFormat(conf)
+ reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+ }
+
+ val key: K = keyClass.newInstance()
+ val value: V = valueClass.newInstance()
+ var gotNext = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!gotNext) {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eofe: java.io.EOFException =>
+ finished = true
+ }
+ gotNext = true
+ }
+ !finished
+ }
+
+ override def next: (K, V) = {
+ if (!gotNext) {
+ finished = !reader.next(key, value)
+ }
+ if (finished) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ (key, value)
+ }
+ }
+
+ override def preferredLocations(split: Split) = {
+ // TODO: Filtering out "localhost" in case of file:// URLs
+ val hadoopSplit = split.asInstanceOf[HadoopSplit]
+ hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ }
+}
+
+
+/**
+ * Convenience class for Hadoop files read using TextInputFormat that
+ * represents the file as an RDD of Strings.
+ */
+class HadoopTextFile(sc: SparkContext, path: String)
+extends MappedRDD[String, (LongWritable, Text)](
+ new HadoopFile(sc, path, classOf[TextInputFormat],
+ classOf[LongWritable], classOf[Text]),
+ { pair: (LongWritable, Text) => pair._2.toString }
+)
+
+
+/**
+ * Object used to ensure that only one thread at a time is configuring Hadoop
+ * InputFormat classes. Apparently configuring them is not thread safe!
+ */
+object ConfigureLock {}
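A usage sketch for the two classes above (illustrative only, not part of this commit): the "local" master string and the HDFS path are assumptions, and the SparkContext construction mirrors the SparkPi example earlier in this diff.

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat

    // Sketch: read a text file as lines, and again as (offset, line) pairs.
    object HadoopFileSketch {
      def main(args: Array[String]) {
        val sc = new spark.SparkContext("local", "HadoopFileSketch")
        val lines = new spark.HadoopTextFile(sc, "hdfs://namenode:9000/data/input.txt")
        println("lines: " + lines.count())
        val pairs = new spark.HadoopFile(sc, "hdfs://namenode:9000/data/input.txt",
          classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
        println("records: " + pairs.count())
      }
    }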
diff --git a/src/scala/spark/HdfsFile.scala b/src/scala/spark/HdfsFile.scala
deleted file mode 100644
index 8637c6e30a..0000000000
--- a/src/scala/spark/HdfsFile.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-package spark
-
-import mesos.SlaveOffer
-
-import org.apache.hadoop.io.LongWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.mapred.FileInputFormat
-import org.apache.hadoop.mapred.InputSplit
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.TextInputFormat
-import org.apache.hadoop.mapred.RecordReader
-import org.apache.hadoop.mapred.Reporter
-
-@serializable class HdfsSplit(@transient s: InputSplit)
-extends Split {
- val inputSplit = new SerializableWritable[InputSplit](s)
-
- override def getId() = inputSplit.toString // Hadoop makes this unique
- // for each split of each file
-}
-
-class HdfsTextFile(sc: SparkContext, path: String)
-extends RDD[String](sc) {
- @transient val conf = new JobConf()
- @transient val inputFormat = new TextInputFormat()
-
- FileInputFormat.setInputPaths(conf, path)
- ConfigureLock.synchronized { inputFormat.configure(conf) }
-
- @transient val splits_ =
- inputFormat.getSplits(conf, sc.scheduler.numCores).map(new HdfsSplit(_)).toArray
-
- override def splits = splits_.asInstanceOf[Array[Split]]
-
- override def iterator(split_in: Split) = new Iterator[String] {
- val split = split_in.asInstanceOf[HdfsSplit]
- var reader: RecordReader[LongWritable, Text] = null
- ConfigureLock.synchronized {
- val conf = new JobConf()
- conf.set("io.file.buffer.size",
- System.getProperty("spark.buffer.size", "65536"))
- val tif = new TextInputFormat()
- tif.configure(conf)
- reader = tif.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
- }
- val lineNum = new LongWritable()
- val text = new Text()
- var gotNext = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!gotNext) {
- try {
- finished = !reader.next(lineNum, text)
- } catch {
- case eofe: java.io.EOFException =>
- finished = true
- }
- gotNext = true
- }
- !finished
- }
-
- override def next: String = {
- if (!gotNext)
- finished = !reader.next(lineNum, text)
- if (finished)
- throw new java.util.NoSuchElementException("end of stream")
- gotNext = false
- text.toString
- }
- }
-
- override def preferredLocations(split: Split) = {
- // TODO: Filtering out "localhost" in case of file:// URLs
- split.asInstanceOf[HdfsSplit].inputSplit.value.getLocations().filter(_ != "localhost")
- }
-}
-
-object ConfigureLock {}
diff --git a/src/scala/spark/HttpServer.scala b/src/scala/spark/HttpServer.scala
new file mode 100644
index 0000000000..d2a663ac1f
--- /dev/null
+++ b/src/scala/spark/HttpServer.scala
@@ -0,0 +1,67 @@
+package spark
+
+import java.io.File
+import java.net.InetAddress
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.handler.DefaultHandler
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+
+
+/**
+ * Exception type thrown by HttpServer when it is in the wrong state
+ * for an operation.
+ */
+class ServerStateException(message: String) extends Exception(message)
+
+
+/**
+ * An HTTP server for static content used to allow worker nodes to access JARs
+ * added to SparkContext as well as classes created by the interpreter when
+ * the user types in code. This is just a wrapper around a Jetty server.
+ */
+class HttpServer(resourceBase: File) extends Logging {
+ private var server: Server = null
+ private var port: Int = -1
+
+ def start() {
+ if (server != null) {
+ throw new ServerStateException("Server is already started")
+ } else {
+ server = new Server(0)
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val resHandler = new ResourceHandler
+ resHandler.setResourceBase(resourceBase.getAbsolutePath)
+ val handlerList = new HandlerList
+ handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+ server.setHandler(handlerList)
+ server.start()
+ port = server.getConnectors()(0).getLocalPort()
+ }
+ }
+
+ def stop() {
+ if (server == null) {
+ throw new ServerStateException("Server is already stopped")
+ } else {
+ server.stop()
+ port = -1
+ server = null
+ }
+ }
+
+ /**
+ * Get the URI of this HTTP server (http://host:port)
+ */
+ def uri: String = {
+ if (server == null) {
+ throw new ServerStateException("Server is not started")
+ } else {
+ return "http://" + Utils.localIpAddress + ":" + port
+ }
+ }
+}
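A minimal lifecycle sketch for the server above (illustrative only, not part of this commit; the directory path is arbitrary):

    // Sketch: serve a directory of files over HTTP, print the base URI, then stop.
    object HttpServerSketch {
      def main(args: Array[String]) {
        val server = new spark.HttpServer(new java.io.File("/tmp/jars"))
        server.start()
        println("Serving at " + server.uri)  // e.g. http://<local-ip>:<ephemeral-port>
        // ... workers would fetch <uri>/<filename> here ...
        server.stop()
      }
    }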
diff --git a/src/scala/spark/Job.scala b/src/scala/spark/Job.scala
new file mode 100644
index 0000000000..6abbcbce51
--- /dev/null
+++ b/src/scala/spark/Job.scala
@@ -0,0 +1,18 @@
+package spark
+
+import mesos._
+
+/**
+ * Class representing a parallel job in MesosScheduler. Schedules the
+ * job by implementing various callbacks.
+ */
+abstract class Job(jobId: Int) {
+ def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int)
+ : Option[TaskDescription]
+
+ def statusUpdate(t: TaskStatus): Unit
+
+ def error(code: Int, message: String): Unit
+
+ def getId(): Int = jobId
+}
diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala
new file mode 100644
index 0000000000..367599cfb4
--- /dev/null
+++ b/src/scala/spark/LocalFileShuffle.scala
@@ -0,0 +1,171 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+
+/**
+ * A simple implementation of shuffle using local files served through HTTP.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+ override def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] =
+ {
+ val sc = input.sparkContext
+ val shuffleId = LocalFileShuffle.newShuffleId()
+ logInfo("Shuffle ID: " + shuffleId)
+
+ val splitRdd = new NumberedSplitRDD(input)
+ val numInputSplits = splitRdd.splits.size
+
+ // Run a parallel map and collect to write the intermediate data files,
+ // returning a list of inputSplitId -> serverUri pairs
+ val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+ val myIndex = pair._1
+ val myIterator = pair._2
+ val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+ for ((k, v) <- myIterator) {
+ var bucketId = k.hashCode % numOutputSplits
+ if (bucketId < 0) { // Fix bucket ID if hash code was negative
+ bucketId += numOutputSplits
+ }
+ val bucket = buckets(bucketId)
+ bucket(k) = bucket.get(k) match {
+ case Some(c) => mergeValue(c, v)
+ case None => createCombiner(v)
+ }
+ }
+ for (i <- 0 until numOutputSplits) {
+ val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+ val out = new ObjectOutputStream(new FileOutputStream(file))
+ buckets(i).foreach(pair => out.writeObject(pair))
+ out.close()
+ }
+ (myIndex, LocalFileShuffle.serverUri)
+ }).collect()
+
+ // Build a hashmap from server URI to list of splits (to facilitate
+ // fetching all the URIs on a server within a single connection)
+ val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+ for ((inputId, serverUri) <- outputLocs) {
+ splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
+ }
+
+ // TODO: Could broadcast splitsByUri
+
+ // Return an RDD that does each of the merges for a given partition
+ val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+ return indexes.flatMap((myId: Int) => {
+ val combiners = new HashMap[K, C]
+ for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
+ for (i <- inputIds) {
+ val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+ val inputStream = new ObjectInputStream(new URL(url).openStream())
+ try {
+ while (true) {
+ val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+ combiners(k) = combiners.get(k) match {
+ case Some(oldC) => mergeCombiners(oldC, c)
+ case None => c
+ }
+ }
+ } catch {
+ case e: EOFException => {}
+ }
+ inputStream.close()
+ }
+ }
+ combiners
+ })
+ }
+}
+
+
+object LocalFileShuffle extends Logging {
+ private var initialized = false
+ private var nextShuffleId = new AtomicLong(0)
+
+ // Variables initialized by initializeIfNeeded()
+ private var shuffleDir: File = null
+ private var server: HttpServer = null
+ private var serverUri: String = null
+
+ private def initializeIfNeeded() = synchronized {
+ if (!initialized) {
+ // TODO: localDir should be created by some mechanism common to Spark
+ // so that it can be shared among shuffle, broadcast, etc
+ val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+ var tries = 0
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirUuid: UUID = null
+ while (!foundLocalDir && tries < 10) {
+ tries += 1
+ try {
+ localDirUuid = UUID.randomUUID()
+ localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+ if (!localDir.exists()) {
+ localDir.mkdirs()
+ foundLocalDir = true
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed 10 attempts to create local dir in " + localDirRoot)
+ System.exit(1)
+ }
+ shuffleDir = new File(localDir, "shuffle")
+ shuffleDir.mkdirs()
+ logInfo("Shuffle dir: " + shuffleDir)
+ val extServerPort = System.getProperty(
+ "spark.localFileShuffle.external.server.port", "-1").toInt
+ if (extServerPort != -1) {
+ // We're using an external HTTP server; set URI relative to its root
+ var extServerPath = System.getProperty(
+ "spark.localFileShuffle.external.server.path", "")
+ if (extServerPath != "" && !extServerPath.endsWith("/")) {
+ extServerPath += "/"
+ }
+ serverUri = "http://%s:%d/%s/spark-local-%s".format(
+ Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+ } else {
+ // Create our own server
+ server = new HttpServer(localDir)
+ server.start()
+ serverUri = server.uri
+ }
+ initialized = true
+ }
+ }
+
+ def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+ initializeIfNeeded()
+ val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+ dir.mkdirs()
+ val file = new File(dir, "" + outputId)
+ return file
+ }
+
+ def getServerUri(): String = {
+ initializeIfNeeded()
+ serverUri
+ }
+
+ def newShuffleId(): Long = {
+ nextShuffleId.getAndIncrement()
+ }
+}
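The file and URL layout used above can be summed up in a short sketch (illustrative only, not part of this commit; the server URI and IDs are made up):

    // Sketch: the map side writes <localDir>/shuffle/<shuffleId>/<inputId>/<outputId>,
    // and the reduce side fetches the same path relative to that worker's server URI.
    object ShuffleUrlSketch {
      def main(args: Array[String]) {
        val serverUri = "http://10.0.0.1:12345"   // stands in for LocalFileShuffle.serverUri
        val (shuffleId, inputId, outputId) = (7, 3, 5)
        val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, outputId)
        println(url)   // the URL a reducer opens in LocalFileShuffle.compute
      }
    }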
diff --git a/src/scala/spark/MesosScheduler.scala b/src/scala/spark/MesosScheduler.scala
index 873a97c59c..c45eff64d4 100644
--- a/src/scala/spark/MesosScheduler.scala
+++ b/src/scala/spark/MesosScheduler.scala
@@ -1,103 +1,130 @@
package spark
-import java.io.File
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{ArrayList => JArrayList}
+import java.util.{List => JList}
+import java.util.{HashMap => JHashMap}
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
import scala.collection.mutable.Map
import scala.collection.mutable.Queue
-import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
-import mesos.{Scheduler => NScheduler}
+import mesos.{Scheduler => MScheduler}
import mesos._
-// The main Scheduler implementation, which talks to Mesos. Clients are expected
-// to first call start(), then submit tasks through the runTasks method.
-//
-// This implementation is currently a little quick and dirty. The following
-// improvements need to be made to it:
-// 1) Right now, the scheduler uses a linear scan through the tasks to find a
-// local one for a given node. It would be faster to have a separate list of
-// pending tasks for each node.
-// 2) Presenting a single slave in ParallelOperation.slaveOffer makes it
-// difficult to balance tasks across nodes. It would be better to pass
-// all the offers to the ParallelOperation and have it load-balance.
+/**
+ * The main Scheduler implementation, which runs jobs on Mesos. Clients should
+ * first call start(), then submit tasks through the runTasks method.
+ */
private class MesosScheduler(
- master: String, frameworkName: String, execArg: Array[Byte])
-extends NScheduler with spark.Scheduler with Logging
+ sc: SparkContext, master: String, frameworkName: String)
+extends MScheduler with spark.Scheduler with Logging
{
- // Lock used by runTasks to ensure only one thread can be in it
- val runTasksMutex = new Object()
+ // Environment variables to pass to our executors
+ val ENV_VARS_TO_SEND_TO_EXECUTORS = Array(
+ "SPARK_MEM",
+ "SPARK_CLASSPATH",
+ "SPARK_LIBRARY_PATH"
+ )
// Lock used to wait for scheduler to be registered
- var isRegistered = false
- val registeredLock = new Object()
+ private var isRegistered = false
+ private val registeredLock = new Object()
- // Current callback object (may be null)
- var activeOpsQueue = new Queue[Int]
- var activeOps = new HashMap[Int, ParallelOperation]
- private var nextOpId = 0
- private[spark] var taskIdToOpId = new HashMap[Int, Int]
-
- def newOpId(): Int = {
- val id = nextOpId
- nextOpId += 1
- return id
- }
+ private var activeJobs = new HashMap[Int, Job]
+ private var activeJobsQueue = new Queue[Job]
+
+ private var taskIdToJobId = new HashMap[Int, Int]
+ private var jobTasks = new HashMap[Int, HashSet[Int]]
- // Incrementing task ID
+ // Incrementing job and task IDs
+ private var nextJobId = 0
private var nextTaskId = 0
+ // Driver for talking to Mesos
+ var driver: SchedulerDriver = null
+
+ // JAR server, if any JARs were added by the user to the SparkContext
+ var jarServer: HttpServer = null
+
+ // URIs of JARs to pass to executor
+ var jarUris: String = ""
+
+ def newJobId(): Int = this.synchronized {
+ val id = nextJobId
+ nextJobId += 1
+ return id
+ }
+
def newTaskId(): Int = {
val id = nextTaskId;
nextTaskId += 1;
return id
}
-
- // Driver for talking to Mesos
- var driver: SchedulerDriver = null
override def start() {
+ if (sc.jars.size > 0) {
+ // If the user added any JARS to the SparkContext, create an HTTP server
+ // to serve them to our executors
+ createJarServer()
+ }
new Thread("Spark scheduler") {
setDaemon(true)
override def run {
- val ns = MesosScheduler.this
- ns.driver = new MesosSchedulerDriver(ns, master)
- ns.driver.run()
+ val sched = MesosScheduler.this
+ sched.driver = new MesosSchedulerDriver(sched, master)
+ sched.driver.run()
}
}.start
}
override def getFrameworkName(d: SchedulerDriver): String = frameworkName
- override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo =
- new ExecutorInfo(new File("spark-executor").getCanonicalPath(), execArg)
+ override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = {
+ val sparkHome = sc.getSparkHome match {
+ case Some(path) => path
+ case None =>
+ throw new SparkException("Spark home is not set; set it through the " +
+ "spark.home system property, the SPARK_HOME environment variable " +
+ "or the SparkContext constructor")
+ }
+ val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
+ val params = new JHashMap[String, String]
+ for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
+ if (System.getenv(key) != null) {
+ params("env." + key) = System.getenv(key)
+ }
+ }
+ new ExecutorInfo(execScript, createExecArg())
+ }
+ /**
+ * The primary means to submit a job to the scheduler. Given a list of tasks,
+ * runs them and returns an array of the results.
+ */
override def runTasks[T: ClassManifest](tasks: Array[Task[T]]): Array[T] = {
- var opId = 0
waitForRegister()
- this.synchronized {
- opId = newOpId()
- }
- val myOp = new SimpleParallelOperation(this, tasks, opId)
-
+ val jobId = newJobId()
+ val myJob = new SimpleJob(this, tasks, jobId)
try {
this.synchronized {
- this.activeOps(myOp.opId) = myOp
- this.activeOpsQueue += myOp.opId
+ activeJobs(jobId) = myJob
+ activeJobsQueue += myJob
+ jobTasks(jobId) = new HashSet()
}
driver.reviveOffers();
- myOp.join();
+ return myJob.join();
} finally {
this.synchronized {
- this.activeOps.remove(myOp.opId)
- this.activeOpsQueue.dequeueAll(x => (x == myOp.opId))
+ activeJobs -= jobId
+ activeJobsQueue.dequeueAll(x => (x == myJob))
+ taskIdToJobId --= jobTasks(jobId)
+ jobTasks.remove(jobId)
}
}
-
- if (myOp.errorHappened)
- throw new SparkException(myOp.errorMessage, myOp.errorCode)
- else
- return myOp.results
}
override def registered(d: SchedulerDriver, frameworkId: String) {
@@ -115,51 +142,68 @@ extends NScheduler with spark.Scheduler with Logging
}
}
+ /**
+ * Method called by Mesos to offer resources on slaves. We respond by asking
+ * our active jobs for tasks in FIFO order. We fill each node with tasks in
+ * a round-robin manner so that tasks are balanced across the cluster.
+ */
override def resourceOffer(
- d: SchedulerDriver, oid: String, offers: java.util.List[SlaveOffer]) {
+ d: SchedulerDriver, oid: String, offers: JList[SlaveOffer]) {
synchronized {
- val tasks = new java.util.ArrayList[TaskDescription]
+ val tasks = new JArrayList[TaskDescription]
val availableCpus = offers.map(_.getParams.get("cpus").toInt)
val availableMem = offers.map(_.getParams.get("mem").toInt)
- var launchedTask = true
- for (opId <- activeOpsQueue) {
- launchedTask = true
- while (launchedTask) {
+ var launchedTask = false
+ for (job <- activeJobsQueue) {
+ do {
launchedTask = false
for (i <- 0 until offers.size.toInt) {
try {
- activeOps(opId).slaveOffer(offers.get(i), availableCpus(i), availableMem(i)) match {
+ job.slaveOffer(offers(i), availableCpus(i), availableMem(i)) match {
case Some(task) =>
tasks.add(task)
+ taskIdToJobId(task.getTaskId) = job.getId
+ jobTasks(job.getId) += task.getTaskId
availableCpus(i) -= task.getParams.get("cpus").toInt
availableMem(i) -= task.getParams.get("mem").toInt
- launchedTask = launchedTask || true
+ launchedTask = true
case None => {}
}
} catch {
case e: Exception => logError("Exception in resourceOffer", e)
}
}
- }
+ } while (launchedTask)
}
- val params = new java.util.HashMap[String, String]
+ val params = new JHashMap[String, String]
params.put("timeout", "1")
- d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout
+ d.replyToOffer(oid, tasks, params) // TODO: use smaller timeout?
}
}
+ // Check whether a Mesos task state represents a finished task
+ def isFinished(state: TaskState) = {
+ state == TaskState.TASK_FINISHED ||
+ state == TaskState.TASK_FAILED ||
+ state == TaskState.TASK_KILLED ||
+ state == TaskState.TASK_LOST
+ }
+
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
synchronized {
try {
- taskIdToOpId.get(status.getTaskId) match {
- case Some(opId) =>
- if (activeOps.contains(opId)) {
- activeOps(opId).statusUpdate(status)
+ taskIdToJobId.get(status.getTaskId) match {
+ case Some(jobId) =>
+ if (activeJobs.contains(jobId)) {
+ activeJobs(jobId).statusUpdate(status)
+ }
+ if (isFinished(status.getState)) {
+ taskIdToJobId.remove(status.getTaskId)
+ jobTasks(jobId) -= status.getTaskId
}
case None =>
logInfo("TID " + status.getTaskId + " already finished")
}
-
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
}
@@ -167,180 +211,84 @@ extends NScheduler with spark.Scheduler with Logging
}
override def error(d: SchedulerDriver, code: Int, message: String) {
+ logError("Mesos error: %s (error code: %d)".format(message, code))
synchronized {
- if (activeOps.size > 0) {
- for ((opId, activeOp) <- activeOps) {
+ if (activeJobs.size > 0) {
+ // Have each job throw a SparkException with the error
+ for ((jobId, activeJob) <- activeJobs) {
try {
- activeOp.error(code, message)
+ activeJob.error(code, message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
} else {
- logError("Mesos error: %s (error code: %d)".format(message, code))
+ // No jobs are active but we still got an error. Just exit since this
+ // must mean the error is during registration.
+ // It might be good to do something smarter here in the future.
System.exit(1)
}
}
}
override def stop() {
- if (driver != null)
+ if (driver != null) {
driver.stop()
- }
-
- // TODO: query Mesos for number of cores
- override def numCores() = System.getProperty("spark.default.parallelism", "2").toInt
-}
-
-
-// Trait representing an object that manages a parallel operation by
-// implementing various scheduler callbacks.
-trait ParallelOperation {
- def slaveOffer(s: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription]
- def statusUpdate(t: TaskStatus): Unit
- def error(code: Int, message: String): Unit
-}
-
-
-class SimpleParallelOperation[T: ClassManifest](
- sched: MesosScheduler, tasks: Array[Task[T]], val opId: Int)
-extends ParallelOperation with Logging
-{
- // Maximum time to wait to run a task in a preferred location (in ms)
- val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
-
- val callingThread = currentThread
- val numTasks = tasks.length
- val results = new Array[T](numTasks)
- val launched = new Array[Boolean](numTasks)
- val finished = new Array[Boolean](numTasks)
- val tidToIndex = Map[Int, Int]()
-
- var allFinished = false
- val joinLock = new Object()
-
- var errorHappened = false
- var errorCode = 0
- var errorMessage = ""
-
- var tasksLaunched = 0
- var tasksFinished = 0
- var lastPreferredLaunchTime = System.currentTimeMillis
-
- def setAllFinished() {
- joinLock.synchronized {
- allFinished = true
- joinLock.notifyAll()
}
- }
-
- def join() {
- joinLock.synchronized {
- while (!allFinished)
- joinLock.wait()
- }
- }
-
- def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int): Option[TaskDescription] = {
- if (tasksLaunched < numTasks) {
- var checkPrefVals: Array[Boolean] = Array(true)
- val time = System.currentTimeMillis
- if (time - lastPreferredLaunchTime > LOCALITY_WAIT)
- checkPrefVals = Array(true, false) // Allow non-preferred tasks
- // TODO: Make desiredCpus and desiredMem configurable
- val desiredCpus = 1
- val desiredMem = 500
- if ((availableCpus < desiredCpus) || (availableMem < desiredMem))
- return None
- for (checkPref <- checkPrefVals; i <- 0 until numTasks) {
- if (!launched(i) && (!checkPref ||
- tasks(i).preferredLocations.contains(offer.getHost) ||
- tasks(i).preferredLocations.isEmpty))
- {
- val taskId = sched.newTaskId()
- sched.taskIdToOpId(taskId) = opId
- tidToIndex(taskId) = i
- val preferred = if(checkPref) "preferred" else "non-preferred"
- val message =
- "Starting task %d as opId %d, TID %s on slave %s: %s (%s)".format(
- i, opId, taskId, offer.getSlaveId, offer.getHost, preferred)
- logInfo(message)
- tasks(i).markStarted(offer)
- launched(i) = true
- tasksLaunched += 1
- if (checkPref)
- lastPreferredLaunchTime = time
- val params = new java.util.HashMap[String, String]
- params.put("cpus", "" + desiredCpus)
- params.put("mem", "" + desiredMem)
- val serializedTask = Utils.serialize(tasks(i))
- //logInfo("Serialized size: " + serializedTask.size)
- return Some(new TaskDescription(taskId, offer.getSlaveId,
- "task_" + taskId, params, serializedTask))
- }
- }
+ if (jarServer != null) {
+ jarServer.stop()
}
- return None
}
- def statusUpdate(status: TaskStatus) {
- status.getState match {
- case TaskState.TASK_FINISHED =>
- taskFinished(status)
- case TaskState.TASK_LOST =>
- taskLost(status)
- case TaskState.TASK_FAILED =>
- taskLost(status)
- case TaskState.TASK_KILLED =>
- taskLost(status)
- case _ =>
+ // TODO: query Mesos for number of cores
+ override def numCores() =
+ System.getProperty("spark.default.parallelism", "2").toInt
+
+ // Create a server for all the JARs added by the user to SparkContext.
+ // We first copy the JARs to a temp directory for easier server setup.
+ private def createJarServer() {
+ val jarDir = Utils.createTempDir()
+ logInfo("Temp directory for JARs: " + jarDir)
+ val filenames = ArrayBuffer[String]()
+ // Copy each JAR to a unique filename in the jarDir
+ for ((path, index) <- sc.jars.zipWithIndex) {
+ val file = new File(path)
+ val filename = index + "_" + file.getName
+ copyFile(file, new File(jarDir, filename))
+ filenames += filename
}
+ // Create the server
+ jarServer = new HttpServer(jarDir)
+ jarServer.start()
+ // Build up the jar URI list
+ val serverUri = jarServer.uri
+ jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
+ logInfo("JAR server started at " + serverUri)
}
- def taskFinished(status: TaskStatus) {
- val tid = status.getTaskId
- val index = tidToIndex(tid)
- if (!finished(index)) {
- tasksFinished += 1
- logInfo("Finished opId %d TID %d (progress: %d/%d)".format(
- opId, tid, tasksFinished, numTasks))
- // Deserialize task result
- val result = Utils.deserialize[TaskResult[T]](status.getData)
- results(index) = result.value
- // Update accumulators
- Accumulators.add(callingThread, result.accumUpdates)
- // Mark finished and stop if we've finished all the tasks
- finished(index) = true
- // Remove TID -> opId mapping from sched
- sched.taskIdToOpId.remove(tid)
- if (tasksFinished == numTasks)
- setAllFinished()
- } else {
- logInfo("Ignoring task-finished event for TID " + tid +
- " because task " + index + " is already finished")
- }
+ // Copy a file on the local file system
+ private def copyFile(source: File, dest: File) {
+ val in = new FileInputStream(source)
+ val out = new FileOutputStream(dest)
+ Utils.copyStream(in, out, true)
}
- def taskLost(status: TaskStatus) {
- val tid = status.getTaskId
- val index = tidToIndex(tid)
- if (!finished(index)) {
- logInfo("Lost opId " + opId + " TID " + tid)
- launched(index) = false
- sched.taskIdToOpId.remove(tid)
- tasksLaunched -= 1
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
+ // Create and serialize the executor argument to pass to Mesos.
+ // Our executor arg is an array containing all the spark.* system properties
+ // in the form of (String, String) pairs.
+ private def createExecArg(): Array[Byte] = {
+ val props = new HashMap[String, String]
+ val iter = System.getProperties.entrySet.iterator
+ while (iter.hasNext) {
+ val entry = iter.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.")) {
+ props(key) = value
+ }
}
- }
-
- def error(code: Int, message: String) {
- // Save the error message
- errorHappened = true
- errorCode = code
- errorMessage = message
- // Indicate to caller thread that we're done
- setAllFinished()
+ // Set spark.jar.uris to our JAR URIs, regardless of system property
+ props("spark.jar.uris") = jarUris
+ // Serialize the map as an array of (String, String) pairs
+ return Utils.serialize(props.toArray)
}
}
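The executor argument built by createExecArg above is just a serialized array of (String, String) pairs; this round-trip sketch (illustrative only, not part of this commit) mirrors what Executor.init does with it on the worker side.

    // Sketch: the driver serializes spark.* properties, the executor restores them.
    object ExecArgSketch {
      def main(args: Array[String]) {
        val props = Array(("spark.default.parallelism", "8"), ("spark.jar.uris", ""))
        val bytes = spark.Utils.serialize(props)                                // driver side
        val restored = spark.Utils.deserialize[Array[(String, String)]](bytes)  // executor side
        for ((key, value) <- restored) System.setProperty(key, value)
        println(System.getProperty("spark.default.parallelism"))                // prints 8
      }
    }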
diff --git a/src/scala/spark/NumberedSplitRDD.scala b/src/scala/spark/NumberedSplitRDD.scala
new file mode 100644
index 0000000000..7b12210d84
--- /dev/null
+++ b/src/scala/spark/NumberedSplitRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import mesos.SlaveOffer
+
+
+/**
+ * An RDD that takes the splits of a parent RDD and gives them unique indexes.
+ * This is useful for a variety of shuffle implementations.
+ */
+class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
+extends RDD[(Int, Iterator[T])](prev.sparkContext) {
+ @transient val splits_ = {
+ prev.splits.zipWithIndex.map {
+ case (s, i) => new NumberedSplitRDDSplit(s, i): Split
+ }.toArray
+ }
+
+ override def splits = splits_
+
+ override def preferredLocations(split: Split) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ prev.preferredLocations(nsplit.prev)
+ }
+
+ override def iterator(split: Split) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ Iterator((nsplit.index, prev.iterator(nsplit.prev)))
+ }
+
+ override def taskStarted(split: Split, slot: SlaveOffer) = {
+ val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+ prev.taskStarted(nsplit.prev, slot)
+ }
+}
+
+
+/**
+ * A split in a NumberedSplitRDD.
+ */
+class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
+ override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
+}
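A short sketch of what the RDD above yields (illustrative only, not part of this commit; the parallelize call and split count are assumptions): one (index, iterator) pair per parent split.

    // Sketch: tag each split of a parallelized collection with its index.
    object NumberedSplitSketch {
      def main(args: Array[String]) {
        val sc = new spark.SparkContext("local", "NumberedSplitSketch")
        val data = sc.parallelize(1 to 100, 4)           // 4 splits
        val numbered = new spark.NumberedSplitRDD(data)
        numbered.foreach { case (index, iter) =>         // one pair per parent split
          println("split " + index + " has " + iter.size + " elements")
        }
      }
    }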
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 803c063865..bac59319a0 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -1,7 +1,6 @@
package spark
import java.util.concurrent.atomic.AtomicLong
-import java.util.concurrent.ConcurrentHashMap
import java.util.HashSet
import java.util.Random
@@ -9,13 +8,13 @@ import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
import scala.collection.mutable.HashMap
+import SparkContext._
+
import mesos._
-import com.google.common.collect.MapMaker
@serializable
-abstract class RDD[T: ClassManifest](
- @transient sc: SparkContext) {
+abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
def splits: Array[Split]
def iterator(split: Split): Iterator[T]
def preferredLocations(split: Split): Seq[String]
@@ -26,7 +25,6 @@ abstract class RDD[T: ClassManifest](
def map[U: ClassManifest](f: T => U) = new MappedRDD(this, sc.clean(f))
def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
- def aggregateSplit() = new SplitRDD(this)
def cache() = new CachedRDD(this)
def sample(withReplacement: Boolean, frac: Double, seed: Int) =
@@ -78,15 +76,28 @@ abstract class RDD[T: ClassManifest](
case _ => throw new UnsupportedOperationException("empty collection")
}
- def count(): Long =
- try { map(x => 1L).reduce(_+_) }
- catch { case e: UnsupportedOperationException => 0L }
+ def count(): Long = {
+ try {
+ map(x => 1L).reduce(_+_)
+ } catch {
+ case e: UnsupportedOperationException => 0L // No elements in RDD
+ }
+ }
- def union(other: RDD[T]) = new UnionRDD(sc, this, other)
- def cartesian[U: ClassManifest](other: RDD[U]) = new CartesianRDD(sc, this, other)
+ def union(other: RDD[T]) = new UnionRDD(sc, Array(this, other))
def ++(other: RDD[T]) = this.union(other)
+ def splitRdd() = new SplitRDD(this)
+
+ def cartesian[U: ClassManifest](other: RDD[U]) =
+ new CartesianRDD(sc, this, other)
+
+ def groupBy[K](func: T => K, numSplits: Int): RDD[(K, Seq[T])] =
+ this.map(t => (func(t), t)).groupByKey(numSplits)
+
+ def groupBy[K](func: T => K): RDD[(K, Seq[T])] =
+ groupBy[K](func, sc.numCores)
}
@serializable
@@ -129,7 +140,7 @@ extends RDDTask[Option[T], T](rdd, split) with Logging {
}
class MappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T], f: T => U)
+ prev: RDD[T], f: T => U)
extends RDD[U](prev.sparkContext) {
override def splits = prev.splits
override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -138,7 +149,7 @@ extends RDD[U](prev.sparkContext) {
}
class FilteredRDD[T: ClassManifest](
- prev: RDD[T], f: T => Boolean)
+ prev: RDD[T], f: T => Boolean)
extends RDD[T](prev.sparkContext) {
override def splits = prev.splits
override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -147,7 +158,7 @@ extends RDD[T](prev.sparkContext) {
}
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T], f: T => Traversable[U])
+ prev: RDD[T], f: T => Traversable[U])
extends RDD[U](prev.sparkContext) {
override def splits = prev.splits
override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -156,7 +167,7 @@ extends RDD[U](prev.sparkContext) {
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
-class SplitRDD[T: ClassManifest](prev: RDD[T])
+class SplitRDD[T: ClassManifest](prev: RDD[T])
extends RDD[Array[T]](prev.sparkContext) {
override def splits = prev.splits
override def preferredLocations(split: Split) = prev.preferredLocations(split)
@@ -171,16 +182,16 @@ extends RDD[Array[T]](prev.sparkContext) {
}
class SampledRDD[T: ClassManifest](
- prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int)
+ prev: RDD[T], withReplacement: Boolean, frac: Double, seed: Int)
extends RDD[T](prev.sparkContext) {
-
+
@transient val splits_ = { val rg = new Random(seed); prev.splits.map(x => new SeededSplit(x, rg.nextInt)) }
override def splits = splits_.asInstanceOf[Array[Split]]
override def preferredLocations(split: Split) = prev.preferredLocations(split.asInstanceOf[SeededSplit].prev)
- override def iterator(splitIn: Split) = {
+ override def iterator(splitIn: Split) = {
val split = splitIn.asInstanceOf[SeededSplit]
val rg = new Random(split.seed);
// Sampling with replacement (TODO: use reservoir sampling to make this more efficient?)
@@ -214,7 +225,7 @@ extends RDD[T](prev.sparkContext) with Logging {
else
prev.preferredLocations(split)
}
-
+
override def iterator(split: Split): Iterator[T] = {
val key = id + "::" + split.getId()
logInfo("CachedRDD split key is " + key)
@@ -261,45 +272,36 @@ private object CachedRDD {
def newId() = nextId.getAndIncrement()
// Stores map results for various splits locally (on workers)
- val cache = new MapMaker().softValues().makeMap[String, AnyRef]()
+ val cache = Cache.newKeySpace()
// Remembers which splits are currently being loaded (on workers)
val loading = new HashSet[String]
}
@serializable
-abstract class UnionSplit[T: ClassManifest] extends Split {
- def iterator(): Iterator[T]
- def preferredLocations(): Seq[String]
- def getId(): String
-}
-
-@serializable
-class UnionSplitImpl[T: ClassManifest](
- rdd: RDD[T], split: Split)
-extends UnionSplit[T] {
- override def iterator() = rdd.iterator(split)
- override def preferredLocations() = rdd.preferredLocations(split)
- override def getId() =
- "UnionSplitImpl(" + split.getId() + ")"
+class UnionSplit[T: ClassManifest](rdd: RDD[T], split: Split)
+extends Split {
+ def iterator() = rdd.iterator(split)
+ def preferredLocations() = rdd.preferredLocations(split)
+ override def getId() = "UnionSplit(" + split.getId() + ")"
}
@serializable
-class UnionRDD[T: ClassManifest](
- sc: SparkContext, rdd1: RDD[T], rdd2: RDD[T])
+class UnionRDD[T: ClassManifest](sc: SparkContext, rdds: Seq[RDD[T]])
extends RDD[T](sc) {
-
- @transient val splits_ : Array[UnionSplit[T]] = {
- val a1 = rdd1.splits.map(s => new UnionSplitImpl(rdd1, s))
- val a2 = rdd2.splits.map(s => new UnionSplitImpl(rdd2, s))
- (a1 ++ a2).toArray
+ @transient val splits_ : Array[Split] = {
+ val splits: Seq[Split] =
+ for (rdd <- rdds; split <- rdd.splits)
+ yield new UnionSplit(rdd, split)
+ splits.toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def splits = splits_
- override def iterator(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator()
+ override def iterator(s: Split): Iterator[T] =
+ s.asInstanceOf[UnionSplit[T]].iterator()
- override def preferredLocations(s: Split): Seq[String] =
+ override def preferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
}
@@ -336,8 +338,8 @@ extends RDD[Pair[T, U]](sc) {
}
}
-@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
- def reduceByKey(func: (V, V) => V): Map[K, V] = {
+@serializable class PairRDDExtras[K, V](self: RDD[(K, V)]) {
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = {
def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
for ((k, v) <- m2) {
m1.get(k) match {
@@ -347,6 +349,70 @@ extends RDD[Pair[T, U]](sc) {
}
return m1
}
- rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
+ self.map(pair => HashMap(pair)).reduce(mergeMaps)
+ }
+
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ numSplits: Int)
+ : RDD[(K, C)] =
+ {
+ val shufClass = Class.forName(System.getProperty(
+ "spark.shuffle.class", "spark.DfsShuffle"))
+ val shuf = shufClass.newInstance().asInstanceOf[Shuffle[K, V, C]]
+ shuf.compute(self, numSplits, createCombiner, mergeValue, mergeCombiners)
}
+
+ def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
+ combineByKey[V]((v: V) => v, func, func, numSplits)
+ }
+
+ def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
+ def createCombiner(v: V) = ArrayBuffer(v)
+ def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
+ val bufs = combineByKey[ArrayBuffer[V]](
+ createCombiner _, mergeValue _, mergeCombiners _, numSplits)
+ bufs.asInstanceOf[RDD[(K, Seq[V])]]
+ }
+
+ def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
+ val vs: RDD[(K, Either[V, W])] = self.map { case (k, v) => (k, Left(v)) }
+ val ws: RDD[(K, Either[V, W])] = other.map { case (k, w) => (k, Right(w)) }
+ (vs ++ ws).groupByKey(numSplits).flatMap {
+ case (k, seq) => {
+ val vbuf = new ArrayBuffer[V]
+ val wbuf = new ArrayBuffer[W]
+ seq.foreach(_ match {
+ case Left(v) => vbuf += v
+ case Right(w) => wbuf += w
+ })
+ for (v <- vbuf; w <- wbuf) yield (k, (v, w))
+ }
+ }
+ }
+
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, numCores)
+ }
+
+ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
+ reduceByKey(func, numCores)
+ }
+
+ def groupByKey(): RDD[(K, Seq[V])] = {
+ groupByKey(numCores)
+ }
+
+ def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
+ join(other, numCores)
+ }
+
+ def numCores = self.sparkContext.numCores
+
+ def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
}
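
The new PairRDDExtras methods compose the same way the ShuffleSuite further down exercises them; a condensed sketch (the implicit conversion on pair RDDs comes from import SparkContext._):

    import spark.SparkContext
    import spark.SparkContext._

    val sc = new SparkContext("local", "pair-rdd-sketch")
    val pairs = sc.parallelize(Array((1, 1), (1, 2), (2, 1)))
    val sums = pairs.reduceByKey(_ + _).collectAsMap()   // Map(1 -> 3, 2 -> 1)
    val groups = pairs.groupByKey(2).collect()           // two output splits
    val other = sc.parallelize(Array((1, 'x'), (3, 'y')))
    val joined = pairs.join(other).collect()             // Array((1,(1,x)), (1,(2,x)))
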
diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala
new file mode 100644
index 0000000000..4c5649b537
--- /dev/null
+++ b/src/scala/spark/Shuffle.scala
@@ -0,0 +1,15 @@
+package spark
+
+/**
+ * A trait for shuffle systems. Given an input RDD and the combiner functions
+ * for PairRDDExtras.combineByKey(), an implementation returns an output RDD.
+ */
+@serializable
+trait Shuffle[K, V, C] {
+ def compute(input: RDD[(K, V)],
+ numOutputSplits: Int,
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C)
+ : RDD[(K, C)]
+}
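
The combineByKey() hunk in RDD.scala above instantiates its Shuffle implementation reflectively from the spark.shuffle.class system property, defaulting to spark.DfsShuffle, so any implementation of this trait needs a public no-argument constructor. A minimal sketch of that selection step; the type parameters are chosen arbitrarily:

    // Select a Shuffle implementation; spark.DfsShuffle is the default used
    // by PairRDDExtras.combineByKey() when the property is unset.
    val shufClass = Class.forName(System.getProperty(
      "spark.shuffle.class", "spark.DfsShuffle"))
    val shuf = shufClass.newInstance().asInstanceOf[Shuffle[Int, Int, Int]]
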
diff --git a/src/scala/spark/SimpleJob.scala b/src/scala/spark/SimpleJob.scala
new file mode 100644
index 0000000000..09846ccc34
--- /dev/null
+++ b/src/scala/spark/SimpleJob.scala
@@ -0,0 +1,272 @@
+package spark
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import mesos._
+
+
+/**
+ * A Job that runs a set of tasks with no interdependencies.
+ */
+class SimpleJob[T: ClassManifest](
+ sched: MesosScheduler, tasks: Array[Task[T]], val jobId: Int)
+extends Job(jobId) with Logging
+{
+ // Maximum time to wait to run a task in a preferred location (in ms)
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ // CPUs and memory to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+ val MEM_PER_TASK = System.getProperty("spark.task.mem", "512").toInt
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = 4
+
+ val callingThread = currentThread
+ val numTasks = tasks.length
+ val results = new Array[T](numTasks)
+ val launched = new Array[Boolean](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val tidToIndex = HashMap[Int, Int]()
+
+ var allFinished = false
+ val joinLock = new Object() // Used to wait for all tasks to finish
+
+ var tasksLaunched = 0
+ var tasksFinished = 0
+
+ // Last time when we launched a preferred task (for delay scheduling)
+ var lastPreferredLaunchTime = System.currentTimeMillis
+
+ // List of pending tasks for each node. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+  // tasks that repeatedly fail because whenever a task fails, it is put
+  // back at the top of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // List containing pending tasks with no locality preferences
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // List containing all pending tasks (also used as a stack, as above)
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Did the job fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Add a task to all the pending-task lists that it should be on.
+ def addPendingTask(index: Int) {
+ val locations = tasks(index).preferredLocations
+ if (locations.size == 0) {
+ pendingTasksWithNoPrefs += index
+ } else {
+ for (host <- locations) {
+ val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ list += index
+ }
+ }
+ allPendingTasks += index
+ }
+
+ // Mark the job as finished and wake up any threads waiting on it
+ def setAllFinished() {
+ joinLock.synchronized {
+ allFinished = true
+ joinLock.notifyAll()
+ }
+ }
+
+ // Wait until the job finishes and return its results
+ def join(): Array[T] = {
+ joinLock.synchronized {
+ while (!allFinished) {
+ joinLock.wait()
+ }
+ if (failed) {
+ throw new SparkException(causeOfFailure)
+ } else {
+ return results
+ }
+ }
+ }
+
+ // Return the pending tasks list for a given host, or an empty list if
+ // there is no map entry for that host
+ def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ // Dequeue a pending task from the given list and return its index.
+ // Return None if the list is empty.
+ // This method also cleans up any tasks in the list that have already
+ // been launched, since we want that to happen lazily.
+ def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (!launched(index) && !finished(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ // Dequeue a pending task for a given node and return its index.
+ // If localOnly is set to false, allow non-local tasks as well.
+ def findTask(host: String, localOnly: Boolean): Option[Int] = {
+ val localTask = findTaskFromList(getPendingTasksForHost(host))
+ if (localTask != None) {
+ return localTask
+ }
+ val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
+ if (noPrefTask != None) {
+ return noPrefTask
+ }
+ if (!localOnly) {
+ return findTaskFromList(allPendingTasks) // Look for non-local task
+ } else {
+ return None
+ }
+ }
+
+ // Does a host count as a preferred location for a task? This is true if
+ // either the task has preferred locations and this host is one, or it has
+  // no preferred locations (in which case we still count the launch as preferred).
+ def isPreferredLocation(task: Task[T], host: String): Boolean = {
+ val locs = task.preferredLocations
+ return (locs.contains(host) || locs.isEmpty)
+ }
+
+ // Respond to an offer of a single slave from the scheduler by finding a task
+ def slaveOffer(offer: SlaveOffer, availableCpus: Int, availableMem: Int)
+ : Option[TaskDescription] = {
+ if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK &&
+ availableMem >= MEM_PER_TASK) {
+ val time = System.currentTimeMillis
+ val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
+ val host = offer.getHost
+ findTask(host, localOnly) match {
+ case Some(index) => {
+ // Found a task; do some bookkeeping and return a Mesos task for it
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ val preferred = isPreferredLocation(task, host)
+        val prefStr = if (preferred) "preferred" else "non-preferred"
+ val message =
+ "Starting task %d:%d as TID %s on slave %s: %s (%s)".format(
+ jobId, index, taskId, offer.getSlaveId, host, prefStr)
+ logInfo(message)
+ // Do various bookkeeping
+ tidToIndex(taskId) = index
+ task.markStarted(offer)
+ launched(index) = true
+ tasksLaunched += 1
+ if (preferred)
+ lastPreferredLaunchTime = time
+ // Create and return the Mesos task object
+ val params = new JHashMap[String, String]
+ params.put("cpus", CPUS_PER_TASK.toString)
+ params.put("mem", MEM_PER_TASK.toString)
+ val serializedTask = Utils.serialize(task)
+ logDebug("Serialized size: " + serializedTask.size)
+ val taskName = "task %d:%d".format(jobId, index)
+ return Some(new TaskDescription(
+ taskId, offer.getSlaveId, taskName, params, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ def statusUpdate(status: TaskStatus) {
+ status.getState match {
+ case TaskState.TASK_FINISHED =>
+ taskFinished(status)
+ case TaskState.TASK_LOST =>
+ taskLost(status)
+ case TaskState.TASK_FAILED =>
+ taskLost(status)
+ case TaskState.TASK_KILLED =>
+ taskLost(status)
+ case _ =>
+ }
+ }
+
+ def taskFinished(status: TaskStatus) {
+ val tid = status.getTaskId
+ val index = tidToIndex(tid)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %d (progress: %d/%d)".format(
+ tid, tasksFinished, numTasks))
+ // Deserialize task result
+ val result = Utils.deserialize[TaskResult[T]](status.getData)
+ results(index) = result.value
+ // Update accumulators
+ Accumulators.add(callingThread, result.accumUpdates)
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks)
+ setAllFinished()
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def taskLost(status: TaskStatus) {
+ val tid = status.getTaskId
+ val index = tidToIndex(tid)
+ if (!finished(index)) {
+ logInfo("Lost TID %d (task %d:%d)".format(tid, jobId, index))
+ launched(index) = false
+ tasksLaunched -= 1
+ // Re-enqueue the task as pending
+ addPendingTask(index)
+ // Mark it as failed
+ if (status.getState == TaskState.TASK_FAILED ||
+ status.getState == TaskState.TASK_LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %d:%d failed more than %d times; aborting job".format(
+ jobId, index, MAX_TASK_FAILURES))
+ abort("Task %d failed more than %d times".format(
+ index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(code: Int, message: String) {
+ // Save the error message
+ abort("Mesos error: %s (error code: %d)".format(message, code))
+ }
+
+ def abort(message: String) {
+ joinLock.synchronized {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ // Indicate to any joining thread that we're done
+ setAllFinished()
+ }
+ }
+}
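
SimpleJob reads its scheduling and resource parameters from system properties; a sketch of overriding them before a job is submitted (the property names come from the constants above, the values are examples only):

    // Delay scheduling: wait up to 5 seconds for a slave that is a preferred
    // (data-local) location before accepting a non-local one.
    System.setProperty("spark.locality.wait", "5000")
    // CPUs and memory requested from Mesos for each task.
    System.setProperty("spark.task.cpus", "1")
    System.setProperty("spark.task.mem", "1024")
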
diff --git a/src/scala/spark/SizeEstimator.scala b/src/scala/spark/SizeEstimator.scala
new file mode 100644
index 0000000000..12dd19d704
--- /dev/null
+++ b/src/scala/spark/SizeEstimator.scala
@@ -0,0 +1,160 @@
+package spark
+
+import java.lang.reflect.Field
+import java.lang.reflect.Modifier
+import java.lang.reflect.{Array => JArray}
+import java.util.IdentityHashMap
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable.ArrayBuffer
+
+
+/**
+ * Estimates the sizes of Java objects (number of bytes of memory they occupy),
+ * for use in memory-aware caches.
+ *
+ * Based on the following JavaWorld article:
+ * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
+ */
+object SizeEstimator {
+ private val OBJECT_SIZE = 8 // Minimum size of a java.lang.Object
+ private val POINTER_SIZE = 4 // Size of an object reference
+
+ // Sizes of primitive types
+ private val BYTE_SIZE = 1
+ private val BOOLEAN_SIZE = 1
+ private val CHAR_SIZE = 2
+ private val SHORT_SIZE = 2
+ private val INT_SIZE = 4
+ private val LONG_SIZE = 8
+ private val FLOAT_SIZE = 4
+ private val DOUBLE_SIZE = 8
+
+ // A cache of ClassInfo objects for each class
+ private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo]
+ classInfos.put(classOf[Object], new ClassInfo(OBJECT_SIZE, Nil))
+
+ /**
+ * The state of an ongoing size estimation. Contains a stack of objects
+ * to visit as well as an IdentityHashMap of visited objects, and provides
+ * utility methods for enqueueing new objects to visit.
+ */
+ private class SearchState {
+ val visited = new IdentityHashMap[AnyRef, AnyRef]
+ val stack = new ArrayBuffer[AnyRef]
+ var size = 0L
+
+ def enqueue(obj: AnyRef) {
+ if (obj != null && !visited.containsKey(obj)) {
+ visited.put(obj, null)
+ stack += obj
+ }
+ }
+
+ def isFinished(): Boolean = stack.isEmpty
+
+ def dequeue(): AnyRef = {
+ val elem = stack.last
+ stack.trimEnd(1)
+ return elem
+ }
+ }
+
+ /**
+ * Cached information about each class. We remember two things: the
+ * "shell size" of the class (size of all non-static fields plus the
+ * java.lang.Object size), and any fields that are pointers to objects.
+ */
+ private class ClassInfo(
+ val shellSize: Long,
+ val pointerFields: List[Field]) {}
+
+ def estimate(obj: AnyRef): Long = {
+ val state = new SearchState
+ state.enqueue(obj)
+ while (!state.isFinished) {
+ visitSingleObject(state.dequeue(), state)
+ }
+ return state.size
+ }
+
+ private def visitSingleObject(obj: AnyRef, state: SearchState) {
+ val cls = obj.getClass
+ if (cls.isArray) {
+ visitArray(obj, cls, state)
+ } else {
+ val classInfo = getClassInfo(cls)
+ state.size += classInfo.shellSize
+ for (field <- classInfo.pointerFields) {
+ state.enqueue(field.get(obj))
+ }
+ }
+ }
+
+ private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
+ val length = JArray.getLength(array)
+ val elementClass = cls.getComponentType
+ if (elementClass.isPrimitive) {
+ state.size += length * primitiveSize(elementClass)
+ } else {
+ state.size += length * POINTER_SIZE
+ for (i <- 0 until length) {
+ state.enqueue(JArray.get(array, i))
+ }
+ }
+ }
+
+ private def primitiveSize(cls: Class[_]): Long = {
+ if (cls == classOf[Byte])
+ BYTE_SIZE
+ else if (cls == classOf[Boolean])
+ BOOLEAN_SIZE
+ else if (cls == classOf[Char])
+ CHAR_SIZE
+ else if (cls == classOf[Short])
+ SHORT_SIZE
+ else if (cls == classOf[Int])
+ INT_SIZE
+ else if (cls == classOf[Long])
+ LONG_SIZE
+ else if (cls == classOf[Float])
+ FLOAT_SIZE
+ else if (cls == classOf[Double])
+ DOUBLE_SIZE
+ else throw new IllegalArgumentException(
+ "Non-primitive class " + cls + " passed to primitiveSize()")
+ }
+
+ /**
+ * Get or compute the ClassInfo for a given class.
+ */
+ private def getClassInfo(cls: Class[_]): ClassInfo = {
+ // Check whether we've already cached a ClassInfo for this class
+ val info = classInfos.get(cls)
+ if (info != null) {
+ return info
+ }
+
+ val parent = getClassInfo(cls.getSuperclass)
+ var shellSize = parent.shellSize
+ var pointerFields = parent.pointerFields
+
+ for (field <- cls.getDeclaredFields) {
+ if (!Modifier.isStatic(field.getModifiers)) {
+ val fieldClass = field.getType
+ if (fieldClass.isPrimitive) {
+ shellSize += primitiveSize(fieldClass)
+ } else {
+ field.setAccessible(true) // Enable future get()'s on this field
+ shellSize += POINTER_SIZE
+ pointerFields = field :: pointerFields
+ }
+ }
+ }
+
+ // Create and cache a new ClassInfo
+ val newInfo = new ClassInfo(shellSize, pointerFields)
+ classInfos.put(cls, newInfo)
+ return newInfo
+ }
+}
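
A small worked example of the estimator: for an array of 16 distinct boxed integers, the constants above give 16 pointers plus 16 Integer shells (object header plus one int field). Actual JVM layouts differ; this is only what SizeEstimator reports.

    val boxed = new Array[java.lang.Integer](16)
    for (i <- 0 until 16) boxed(i) = new java.lang.Integer(i)
    val bytes = SizeEstimator.estimate(boxed)
    // 16 * POINTER_SIZE + 16 * (OBJECT_SIZE + INT_SIZE) = 64 + 192 = 256
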
diff --git a/src/scala/spark/SoftReferenceCache.scala b/src/scala/spark/SoftReferenceCache.scala
new file mode 100644
index 0000000000..e84aa57efa
--- /dev/null
+++ b/src/scala/spark/SoftReferenceCache.scala
@@ -0,0 +1,13 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses soft references.
+ */
+class SoftReferenceCache extends Cache {
+ val map = new MapMaker().softValues().makeMap[Any, Any]()
+
+ override def get(key: Any): Any = map.get(key)
+ override def put(key: Any, value: Any) = map.put(key, value)
+}
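
SoftReferenceCache (and the WeakReferenceCache further down) implement the small get/put Cache interface that CachedRDD consumes through Cache.newKeySpace(); a sketch in which the key string is illustrative:

    val cache: Cache = new SoftReferenceCache
    cache.put("CachedRDD 0::split 3", Array(1, 2, 3))
    // Returns the cached array, or null once the soft reference has been
    // collected under memory pressure.
    val hit = cache.get("CachedRDD 0::split 3")
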
diff --git a/src/scala/spark/SparkContext.scala b/src/scala/spark/SparkContext.scala
index ef328d821a..8b8e408266 100644
--- a/src/scala/spark/SparkContext.scala
+++ b/src/scala/spark/SparkContext.scala
@@ -1,26 +1,55 @@
package spark
import java.io._
-import java.util.UUID
import scala.collection.mutable.ArrayBuffer
import scala.actors.Actor._
-class SparkContext(master: String, frameworkName: String) extends Logging {
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+
+
+class SparkContext(
+ master: String,
+ frameworkName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil)
+extends Logging {
+ private var scheduler: Scheduler = {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ master match {
+ case "local" =>
+ new LocalScheduler(1)
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt)
+ case _ =>
+ System.loadLibrary("mesos")
+ new MesosScheduler(this, master, frameworkName)
+ }
+ }
+
+ private val isLocal = scheduler.isInstanceOf[LocalScheduler]
+
+ // Start the scheduler, the cache and the broadcast system
+ scheduler.start()
+ Cache.initialize()
Broadcast.initialize(true)
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
+ // Methods for creating RDDs
+
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int): RDD[T] =
new ParallelArray[T](this, seq, numSlices)
- def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
- parallelize(seq, scheduler.numCores)
+ def parallelize[T: ClassManifest](seq: Seq[T]): RDD[T] =
+ parallelize(seq, numCores)
- def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
- new Accumulator(initialValue, param)
+ def textFile(path: String): RDD[String] =
+ new HadoopTextFile(this, path)
// TODO: Keep around a weak hash map of values to Cached versions?
- // def broadcast[T](value: T) = new DfsBroadcast(value, local)
- def broadcast[T](value: T) = new ChainedBroadcast(value, local)
+ // def broadcast[T](value: T) = new DfsBroadcast(value, isLocal)
+ def broadcast[T](value: T) = new ChainedBroadcast(value, isLocal)
// def broadcast[T](value: T) = {
// val broadcastClass = System.getProperty("spark.broadcast.Class",
@@ -35,39 +64,88 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
// instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
// }
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V](path: String,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V])
+ : RDD[(K, V)] = {
+ new HadoopFile(this, path, inputFormatClass, keyClass, valueClass)
+ }
- def textFile(path: String) = new HdfsTextFile(this, path)
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out
+ * the classes of keys, values and the InputFormat so that users don't need
+ * to pass them directly.
+ */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ : RDD[(K, V)] = {
+ hadoopFile(path,
+ fm.erasure.asInstanceOf[Class[F]],
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]])
+ }
- val LOCAL_REGEX = """local\[([0-9]+)\]""".r
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types */
+ def sequenceFile[K, V](path: String,
+ keyClass: Class[K],
+ valueClass: Class[V]): RDD[(K, V)] = {
+ val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
+ hadoopFile(path, inputFormatClass, keyClass, valueClass)
+ }
- private[spark] var scheduler: Scheduler = master match {
- case "local" => new LocalScheduler(1)
- case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
- case _ => { System.loadLibrary("mesos");
- new MesosScheduler(master, frameworkName, createExecArg()) }
+ /**
+ * Smarter version of sequenceFile() that obtains the key and value classes
+ * from ClassManifests instead of requiring the user to pass them directly.
+ */
+ def sequenceFile[K, V](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V]): RDD[(K, V)] = {
+ sequenceFile(path,
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]])
}
- private val local = scheduler.isInstanceOf[LocalScheduler]
+ /** Build the union of a list of RDDs. */
+ def union[T: ClassManifest](rdds: RDD[T]*): RDD[T] =
+ new UnionRDD(this, rdds)
- scheduler.start()
+ // Methods for creating shared variables
- private def createExecArg(): Array[Byte] = {
- // Our executor arg is an array containing all the spark.* system properties
- val props = new ArrayBuffer[(String, String)]
- val iter = System.getProperties.entrySet.iterator
- while (iter.hasNext) {
- val entry = iter.next
- val (key, value) = (entry.getKey.toString, entry.getValue.toString)
- if (key.startsWith("spark."))
- props += key -> value
- }
- return Utils.serialize(props.toArray)
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
+ new Accumulator(initialValue, param)
+
+ // Stop the SparkContext
+ def stop() {
+ scheduler.stop()
+ scheduler = null
+ }
+
+ // Wait for the scheduler to be registered
+ def waitForRegister() {
+ scheduler.waitForRegister()
+ }
+
+ // Get Spark's home location from either a value set through the constructor,
+ // or the spark.home Java property, or the SPARK_HOME environment variable
+  // (in that order of preference). If none of these is set, return None.
+ def getSparkHome(): Option[String] = {
+ if (sparkHome != null)
+ Some(sparkHome)
+ else if (System.getProperty("spark.home") != null)
+ Some(System.getProperty("spark.home"))
+ else if (System.getenv("SPARK_HOME") != null)
+ Some(System.getenv("SPARK_HOME"))
+ else
+ None
}
+ // Submit an array of tasks (passed as functions) to the scheduler
def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = {
runTaskObjects(tasks.map(f => new FunctionTask(f)))
}
+ // Run an array of spark.Task objects
private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]])
: Array[T] = {
logInfo("Running " + tasks.length + " tasks in parallel")
@@ -77,23 +155,22 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
return result
}
- def stop() {
- scheduler.stop()
- scheduler = null
- }
-
- def waitForRegister() {
- scheduler.waitForRegister()
- }
-
// Clean a closure to make it ready to be serialized and sent to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
+
+ // Get the number of cores available to run tasks (as reported by Scheduler)
+ def numCores = scheduler.numCores
}
+
+/**
+ * The SparkContext object contains a number of implicit conversions and
+ * parameters for use with various Spark features.
+ */
object SparkContext {
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
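
Tying the new SparkContext pieces together; a sketch in which the master string, paths and JAR name are placeholders:

    import spark.SparkContext

    val sc = new SparkContext("local[2]", "sketch",
                              sparkHome = "/opt/spark",       // placeholder
                              jars = Seq("my-app.jar"))       // placeholder
    val lines = sc.textFile("hdfs://namenode:9000/input.txt") // placeholder path
    val pairs = sc.sequenceFile[org.apache.hadoop.io.Text,
                                org.apache.hadoop.io.Text]("hdfs://namenode:9000/pairs.seq")
    val both = sc.union(lines, sc.textFile("hdfs://namenode:9000/more.txt"))
    println("cores available: " + sc.numCores)
    sc.stop()
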
diff --git a/src/scala/spark/SparkException.scala b/src/scala/spark/SparkException.scala
index 7257bf7b0c..6f9be1a94f 100644
--- a/src/scala/spark/SparkException.scala
+++ b/src/scala/spark/SparkException.scala
@@ -1,7 +1,3 @@
package spark
-class SparkException(message: String) extends Exception(message) {
- def this(message: String, errorCode: Int) {
- this("%s (error code: %d)".format(message, errorCode))
- }
-}
+class SparkException(message: String) extends Exception(message) {}
diff --git a/src/scala/spark/Split.scala b/src/scala/spark/Split.scala
index 0f7a21354d..116cd16370 100644
--- a/src/scala/spark/Split.scala
+++ b/src/scala/spark/Split.scala
@@ -3,7 +3,7 @@ package spark
/**
* A partition of an RDD.
*/
-trait Split {
+@serializable trait Split {
/**
* Get a unique ID for this split which can be used, for example, to
* set up caches based on it. The ID should stay the same if we serialize
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
index 27d73aefbd..e333dd9c91 100644
--- a/src/scala/spark/Utils.scala
+++ b/src/scala/spark/Utils.scala
@@ -1,12 +1,18 @@
package spark
import java.io._
+import java.net.InetAddress
+import java.util.UUID
import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+/**
+ * Various utility methods used by Spark.
+ */
object Utils {
def serialize[T](o: T): Array[Byte] = {
- val bos = new ByteArrayOutputStream
+ val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(o)
oos.close
@@ -50,4 +56,72 @@ object Utils {
}
return buf
}
+
+ // Create a temporary directory inside the given parent directory
+ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File =
+ {
+ var attempts = 0
+ val maxAttempts = 10
+ var dir: File = null
+ while (dir == null) {
+ attempts += 1
+ if (attempts > maxAttempts) {
+ throw new IOException("Failed to create a temp directory " +
+ "after " + maxAttempts + " attempts!")
+ }
+ try {
+ dir = new File(root, "spark-" + UUID.randomUUID.toString)
+ if (dir.exists() || !dir.mkdirs()) {
+ dir = null
+ }
+ } catch { case e: IOException => ; }
+ }
+ return dir
+ }
+
+ // Copy all data from an InputStream to an OutputStream
+ def copyStream(in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false)
+ {
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = in.read(buf)
+ if (n != -1) {
+ out.write(buf, 0, n)
+ }
+ }
+ if (closeStreams) {
+ in.close()
+ out.close()
+ }
+ }
+
+ // Shuffle the elements of a collection into a random order, returning the
+ // result in a new collection. Unlike scala.util.Random.shuffle, this method
+ // uses a local random number generator, avoiding inter-thread contention.
+ def shuffle[T](seq: TraversableOnce[T]): Seq[T] = {
+ val buf = new ArrayBuffer[T]()
+ buf ++= seq
+ val rand = new Random()
+ for (i <- (buf.size - 1) to 1 by -1) {
+      val j = rand.nextInt(i + 1) // include i itself so the shuffle is unbiased
+ val tmp = buf(j)
+ buf(j) = buf(i)
+ buf(i) = tmp
+ }
+ buf
+ }
+
+ /**
+ * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4)
+ */
+ def localIpAddress(): String = {
+ // Get local IP as an array of four bytes
+ val bytes = InetAddress.getLocalHost().getAddress()
+ // Convert the bytes to ints (keeping in mind that they may be negative)
+ // and join them into a string
+ return bytes.map(b => (b.toInt + 256) % 256).mkString(".")
+ }
}
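
The new Utils helpers compose with the copyFile() change in MesosScheduler above; a sketch copying a file into a fresh scratch directory (the file name is a placeholder):

    import java.io.{File, FileInputStream, FileOutputStream}

    val scratch = Utils.createTempDir()      // defaults to java.io.tmpdir
    val source = new File("data.bin")        // placeholder
    val dest = new File(scratch, source.getName)
    Utils.copyStream(new FileInputStream(source),
                     new FileOutputStream(dest),
                     closeStreams = true)
    println("worker address: " + Utils.localIpAddress())
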
diff --git a/src/scala/spark/WeakReferenceCache.scala b/src/scala/spark/WeakReferenceCache.scala
new file mode 100644
index 0000000000..ddca065454
--- /dev/null
+++ b/src/scala/spark/WeakReferenceCache.scala
@@ -0,0 +1,14 @@
+package spark
+
+import com.google.common.collect.MapMaker
+
+/**
+ * An implementation of Cache that uses weak references.
+ */
+class WeakReferenceCache extends Cache {
+ val map = new MapMaker().weakValues().makeMap[Any, Any]()
+
+ override def get(key: Any): Any = map.get(key)
+ override def put(key: Any, value: Any) = map.put(key, value)
+}
+
diff --git a/src/scala/spark/repl/SparkInterpreter.scala b/src/scala/spark/repl/SparkInterpreter.scala
index ae2e7e8a68..10ea346658 100644
--- a/src/scala/spark/repl/SparkInterpreter.scala
+++ b/src/scala/spark/repl/SparkInterpreter.scala
@@ -36,6 +36,9 @@ import scala.tools.nsc.{ InterpreterResults => IR }
import interpreter._
import SparkInterpreter._
+import spark.HttpServer
+import spark.Utils
+
/** <p>
* An interpreter for Scala code.
* </p>
@@ -92,27 +95,12 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) {
/** Local directory to save .class files to */
val outputDir = {
- val rootDir = new File(System.getProperty("spark.repl.classdir",
- System.getProperty("java.io.tmpdir")))
- var attempts = 0
- val maxAttempts = 10
- var dir: File = null
- while (dir == null) {
- attempts += 1
- if (attempts > maxAttempts) {
- throw new IOException("Failed to create a temp directory " +
- "after " + maxAttempts + " attempts!")
- }
- try {
- dir = new File(rootDir, "spark-" + UUID.randomUUID.toString)
- if (dir.exists() || !dir.mkdirs())
- dir = null
- } catch { case e: IOException => ; }
- }
- if (SPARK_DEBUG_REPL) {
- println("Output directory: " + dir)
- }
- dir
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = System.getProperty("spark.repl.classdir", tmp)
+ Utils.createTempDir(rootDir)
+ }
+ if (SPARK_DEBUG_REPL) {
+ println("Output directory: " + outputDir)
}
/** Scala compiler virtual directory for outputDir */
@@ -120,14 +108,14 @@ class SparkInterpreter(val settings: Settings, out: PrintWriter) {
val virtualDirectory = new PlainFile(outputDir)
/** Jetty server that will serve our classes to worker nodes */
- val classServer = new ClassServer(outputDir)
+ val classServer = new HttpServer(outputDir)
// Start the classServer and store its URI in a spark system property
// (which will be passed to executors so that they can connect to it)
classServer.start()
System.setProperty("spark.repl.class.uri", classServer.uri)
if (SPARK_DEBUG_REPL) {
- println("ClassServer started, URI = " + classServer.uri)
+ println("Class server started, URI = " + classServer.uri)
}
/** reporter */
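
The interpreter now serves REPL-generated classes through the generic spark.HttpServer and advertises its URI via the spark.repl.class.uri property, which executors read to fetch classes. A sketch of that setup outside the REPL, assuming only the constructor, start() and uri shown in this hunk:

    import spark.{HttpServer, Utils}

    val classDir = Utils.createTempDir()
    val server = new HttpServer(classDir)
    server.start()
    // Executors pick this property up and load classes over HTTP from the URI.
    System.setProperty("spark.repl.class.uri", server.uri)
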
diff --git a/src/test/spark/ShuffleSuite.scala b/src/test/spark/ShuffleSuite.scala
new file mode 100644
index 0000000000..a5773614e8
--- /dev/null
+++ b/src/test/spark/ShuffleSuite.scala
@@ -0,0 +1,130 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+class ShuffleSuite extends FunSuite {
+ test("groupByKey") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with duplicates") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with negative key hash codes") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
+ val groups = pairs.groupByKey().collect()
+ assert(groups.size === 2)
+ val valuesForMinus1 = groups.find(_._1 == -1).get._2
+ assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("groupByKey with many output partitions") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
+ val groups = pairs.groupByKey(10).collect()
+ assert(groups.size === 2)
+ val valuesFor1 = groups.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1, 2, 3))
+ val valuesFor2 = groups.find(_._1 == 2).get._2
+ assert(valuesFor2.toList.sorted === List(1))
+ }
+
+ test("reduceByKey") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("reduceByKey with collectAsMap") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_).collectAsMap()
+ assert(sums.size === 2)
+ assert(sums(1) === 7)
+ assert(sums(2) === 1)
+ }
+
+ test("reduceByKey with many output partitons") {
+ val sc = new SparkContext("local", "test")
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+ val sums = pairs.reduceByKey(_+_, 10).collect()
+ assert(sums.toSet === Set((1, 7), (2, 1)))
+ }
+
+ test("join") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+
+ test("join all-to-all") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 6)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (1, 'y')),
+ (1, (2, 'x')),
+ (1, (2, 'y')),
+ (1, (3, 'x')),
+ (1, (3, 'y'))
+ ))
+ }
+
+ test("join with no matches") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
+ val joined = rdd1.join(rdd2).collect()
+ assert(joined.size === 0)
+ }
+
+ test("join with many output partitions") {
+ val sc = new SparkContext("local", "test")
+ val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+ val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
+ val joined = rdd1.join(rdd2, 10).collect()
+ assert(joined.size === 4)
+ assert(joined.toSet === Set(
+ (1, (1, 'x')),
+ (1, (2, 'x')),
+ (2, (1, 'y')),
+ (2, (1, 'z'))
+ ))
+ }
+}
diff --git a/src/test/spark/repl/ReplSuite.scala b/src/test/spark/repl/ReplSuite.scala
index dcf71182ec..8b38cde85f 100644
--- a/src/test/spark/repl/ReplSuite.scala
+++ b/src/test/spark/repl/ReplSuite.scala
@@ -39,9 +39,9 @@ class ReplSuite extends FunSuite {
test ("external vars") {
val output = runInterpreter("local", """
var v = 7
- sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
v = 10
- sc.parallelize(1 to 10).map(x => v).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -54,7 +54,7 @@ class ReplSuite extends FunSuite {
class C {
def foo = 5
}
- sc.parallelize(1 to 10).map(x => (new C).foo).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -64,7 +64,7 @@ class ReplSuite extends FunSuite {
test ("external functions") {
val output = runInterpreter("local", """
def double(x: Int) = x + x
- sc.parallelize(1 to 10).map(x => double(x)).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -75,9 +75,9 @@ class ReplSuite extends FunSuite {
val output = runInterpreter("local", """
var v = 7
def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
v = 10
- sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -92,9 +92,9 @@ class ReplSuite extends FunSuite {
val output = runInterpreter("local", """
var array = new Array[Int](5)
val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -102,24 +102,26 @@ class ReplSuite extends FunSuite {
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
}
- test ("running on Mesos") {
- val output = runInterpreter("localquiet", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).toArray.reduceLeft(_+_)
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).toArray
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 70", output)
- assertContains("res1: Int = 100", output)
- assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ if (System.getenv("MESOS_HOME") != null) {
+ test ("running on Mesos") {
+ val output = runInterpreter("localquiet", """
+ var v = 7
+ def getV() = v
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ v = 10
+ sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ var array = new Array[Int](5)
+ val broadcastArray = sc.broadcast(array)
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ array(0) = 5
+ sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
}
}
diff --git a/third_party/guava-r06/guava-r06.jar b/third_party/guava-r06/guava-r06.jar
deleted file mode 100644
index 8ff3a81748..0000000000
--- a/third_party/guava-r06/guava-r06.jar
+++ /dev/null
Binary files differ
diff --git a/third_party/guava-r06/COPYING b/third_party/guava-r07/COPYING
index d645695673..d645695673 100644
--- a/third_party/guava-r06/COPYING
+++ b/third_party/guava-r07/COPYING
diff --git a/third_party/guava-r06/README b/third_party/guava-r07/README
index a0e832dd54..a0e832dd54 100644
--- a/third_party/guava-r06/README
+++ b/third_party/guava-r07/README
diff --git a/third_party/guava-r07/guava-r07.jar b/third_party/guava-r07/guava-r07.jar
new file mode 100644
index 0000000000..a6c9ce02df
--- /dev/null
+++ b/third_party/guava-r07/guava-r07.jar
Binary files differ
diff --git a/third_party/mesos.jar b/third_party/mesos.jar
index 1852cf8fd0..60d299c8af 100644
--- a/third_party/mesos.jar
+++ b/third_party/mesos.jar
Binary files differ