author    Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-11 17:04:17 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-11 17:04:17 -0700
commit    a29ac5f9cf3b63cdb0bdd864dc0fea3d3d8db095 (patch)
tree      eeeadcd957958ad0210d86d0e8defde534ab28eb
parent    943df48348662d1ca17091dd403c5365e27924a8 (diff)
parent    5e4076e3f2eb6b0206119c5d67ac6ee405cee1ad (diff)
Merge pull request #195 from dennybritz/feature/fileserver
Spark HTTP FileServer
-rw-r--r--  core/src/main/scala/spark/HttpFileServer.scala                     |  45
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                       |  62
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala                           |  15
-rw-r--r--  core/src/main/scala/spark/Utils.scala                              |  57
-rw-r--r--  core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala       |  32
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala                  |  81
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala           |  56
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala                     |  31
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala |  32
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala     |  26
-rw-r--r--  core/src/test/resources/uncommons-maths-1.2.2.jar                  |  bin 0 -> 49019 bytes
-rw-r--r--  core/src/test/scala/spark/FileServerSuite.scala                    |  93
12 files changed, 414 insertions(+), 116 deletions(-)
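
Summary: SparkContext gains addFile() and addJar(), which publish local dependencies through a new per-driver HttpFileServer (started by SparkEnv), and each Task now carries the file/jar sets so executors can fetch them with Utils.fetchFile before running. A minimal driver-side usage sketch, not part of the commit and with hypothetical paths:

    import spark.SparkContext

    object FileServerUsageSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "file-server-example")
        sc.addFile("/tmp/lookup-table.txt")   // published through the driver's HTTP file server
        sc.addJar("/tmp/extra-library.jar")   // shipped to every executor's classpath
        val doubled = sc.parallelize(1 to 10).map(_ * 2).collect()
        println(doubled.mkString(","))
        sc.stop()                             // stop() now also calls clearFiles() and clearJars()
      }
    }
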
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
new file mode 100644
index 0000000000..e6ad4dd28e
--- /dev/null
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -0,0 +1,45 @@
+package spark
+
+import java.io.{File, PrintWriter}
+import java.net.URL
+import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.FileUtil
+
+class HttpFileServer extends Logging {
+
+ var baseDir : File = null
+ var fileDir : File = null
+ var jarDir : File = null
+ var httpServer : HttpServer = null
+ var serverUri : String = null
+
+ def initialize() {
+ baseDir = Utils.createTempDir()
+ fileDir = new File(baseDir, "files")
+ jarDir = new File(baseDir, "jars")
+ fileDir.mkdir()
+ jarDir.mkdir()
+ logInfo("HTTP File server directory is " + baseDir)
+ httpServer = new HttpServer(fileDir)
+ httpServer.start()
+ serverUri = httpServer.uri
+ }
+
+ def stop() {
+ httpServer.stop()
+ }
+
+ def addFile(file: File) : String = {
+ return addFileToDir(file, fileDir)
+ }
+
+ def addJar(file: File) : String = {
+ return addFileToDir(file, jarDir)
+ }
+
+ def addFileToDir(file: File, dir: File) : String = {
+ Utils.copyFile(file, new File(dir, file.getName))
+ return dir + "/" + file.getName
+ }
+
+} \ No newline at end of file
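
The new HttpFileServer wraps the existing spark.HttpServer: initialize() creates a temporary base directory with files/ and jars/ subdirectories and starts the server, while addFile()/addJar() copy a file into the matching subdirectory and return the location of the copy. A minimal sketch of driving it directly (SparkEnv normally owns it; the input path is hypothetical):

    import java.io.File
    import spark.HttpFileServer

    // Driving HttpFileServer directly; in practice SparkEnv creates and owns it.
    object HttpFileServerSketch {
      def main(args: Array[String]) {
        val server = new HttpFileServer()
        server.initialize()                                     // create temp dirs, start the HTTP server
        println("File server URI: " + server.serverUri)
        val copied = server.addFile(new File("/tmp/data.csv"))  // hypothetical input path
        println("Copied into the served directory: " + copied)
        server.stop()
      }
    }
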
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0dec44979f..758c42fa61 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,14 +2,15 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
+import java.net.{URI, URLClassLoader}
import akka.actor.Actor
import akka.actor.Actor._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.generic.Growable
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -77,7 +78,14 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
-
+
+ // Used to store a URL for each static file/jar together with the file's local timestamp
+ val addedFiles = HashMap[String, Long]()
+ val addedJars = HashMap[String, Long]()
+
+ // Add each JAR given through the constructor
+ jars.foreach { addJar(_) }
+
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -91,13 +99,13 @@ class SparkContext(
master match {
case "local" =>
- new LocalScheduler(1, 0)
+ new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0)
+ new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt)
+ new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
@@ -132,7 +140,7 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
-
+
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
@@ -321,7 +329,44 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
+
+ // Adds a file dependency to all Tasks executed in the future.
+ def addFile(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
+ case _ => path
+ }
+ addedFiles(key) = System.currentTimeMillis
+
+ // Fetch the file locally in case the task is executed locally
+ val filename = new File(path.split("/").last)
+ Utils.fetchFile(path, new File("."))
+
+ logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+ }
+ def clearFiles() {
+ addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedFiles.clear()
+ }
+
+ // Adds a jar dependency to all Tasks executed in the future.
+ def addJar(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
+ }
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+ }
+
+ def clearJars() {
+ addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedJars.clear()
+ }
+
// Stop the SparkContext
def stop() {
dagScheduler.stop()
@@ -329,6 +374,9 @@ class SparkContext(
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext")
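
addFile and addJar decide what to record by URI scheme: a path with no scheme or a file: scheme is first handed to the driver's HttpFileServer, while any other URI (hdfs://, http://, ...) is kept verbatim for executors to fetch themselves. The check in isolation, as a sketch with hypothetical paths:

    import java.net.URI

    // Sketch of the scheme check addFile/addJar use to decide between publishing a file
    // on the driver's HTTP file server and recording the URI as-is.
    object SchemeCheckSketch {
      def isLocalPath(path: String): Boolean = {
        val scheme = new URI(path).getScheme
        scheme == null || scheme == "file"
      }

      def main(args: Array[String]) {
        println(isLocalPath("/tmp/lookup-table.txt"))                        // true  -> upload to file server
        println(isLocalPath("hdfs://namenode:9000/shared/lookup-table.txt")) // false -> record verbatim
      }
    }
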
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index add8fcec51..a95d1bc8ea 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -19,15 +19,17 @@ class SparkEnv (
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
- val connectionManager: ConnectionManager
+ val connectionManager: ConnectionManager,
+ val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
- this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
+ this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null, null)
}
def stop() {
+ httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
@@ -95,7 +97,11 @@ object SparkEnv {
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
-
+
+ val httpFileServer = new HttpFileServer()
+ httpFileServer.initialize()
+ System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+
/*
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
@@ -126,6 +132,7 @@ object SparkEnv {
shuffleManager,
broadcastManager,
blockManager,
- connectionManager)
+ connectionManager,
+ httpFileServer)
}
}
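
SparkEnv now creates and starts the HttpFileServer alongside the rest of the environment, publishes its address in the spark.fileserver.uri system property, and shuts it down in stop(). A sketch of reading that property elsewhere in the same JVM (the property name comes from this commit; the rest is illustrative):

    object FileServerUriSketch {
      def main(args: Array[String]) {
        // Set by SparkEnv.createFromSystemProperties once the file server is up.
        Option(System.getProperty("spark.fileserver.uri")) match {
          case Some(uri) => println("Driver file server at " + uri)
          case None      => println("No SparkEnv has started a file server in this JVM")
        }
      }
    }
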
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 5eda1011f9..07aa18e540 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,18 +1,19 @@
package spark
import java.io._
-import java.net.InetAddress
+import java.net.{InetAddress, URL, URI}
+import java.util.{Locale, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
-
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
-import java.util.{Locale, UUID}
import scala.io.Source
/**
* Various utility methods used by Spark.
*/
-object Utils {
+object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@@ -115,6 +116,54 @@ object Utils {
val out = new FileOutputStream(dest)
copyStream(in, out, true)
}
+
+
+
+ /* Download a file from a given URL to the local filesystem */
+ def downloadFile(url: URL, localPath: String) {
+ val in = url.openStream()
+ val out = new FileOutputStream(localPath)
+ Utils.copyStream(in, out, true)
+ }
+
+ /**
+ * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val targetFile = new File(targetDir, filename)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + targetFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ case "file" | null =>
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally
+ logInfo("Symlinking " + url + " to " + targetFile)
+ FileUtil.symLink(url, targetFile.toString)
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val uri = new URI(url)
+ val conf = new Configuration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ }
+ // Decompress the file if it's a .tar or .tar.gz
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+ }
/**
* Shuffle the elements of a collection into a random order, returning the
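
Utils.fetchFile, moved here from ExecutorRunner and generalized, dispatches on the URL scheme: http/https/ftp are read with java.net, file (or no scheme) is symlinked into the target directory, and anything else goes through the Hadoop FileSystem API; .tar and .tar.gz archives are unpacked afterwards. A sketch of a call (URL and directory hypothetical):

    import java.io.File
    import spark.Utils

    object FetchFileSketch {
      def main(args: Array[String]) {
        val workDir = new File("/tmp/fetch-example")   // hypothetical target directory
        workDir.mkdirs()
        // Any scheme handled above works: http, https, ftp, file, or anything
        // Hadoop's FileSystem API understands (hdfs://, s3://, ...).
        Utils.fetchFile("http://example.com/deps/tools.tar.gz", workDir)
        // workDir now holds tools.tar.gz plus its untarred contents.
      }
    }
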
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 1740a42a7e..7043361020 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -65,38 +65,6 @@ class ExecutorRunner(
}
}
- /**
- * Download a file requested by the executor. Supports fetching the file in a variety of ways,
- * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
- */
- def fetchFile(url: String, targetDir: File) {
- val filename = url.split("/").last
- val targetFile = new File(targetDir, filename)
- if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
- // Use the java.net library to fetch it
- logInfo("Fetching " + url + " to " + targetFile)
- val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- } else {
- // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val uri = new URI(url)
- val conf = new Configuration()
- val fs = FileSystem.get(uri, conf)
- val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- }
- // Decompress the file if it's a .tar or .tar.gz
- if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xzf", filename), targetDir)
- } else if (filename.endsWith(".tar")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xf", filename), targetDir)
- }
- }
-
/** Replace variables such as {{SLAVEID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{SLAVEID}}" => workerId
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index dba209ac27..8f975c52d4 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -1,10 +1,12 @@
package spark.executor
import java.io.{File, FileOutputStream}
-import java.net.{URL, URLClassLoader}
+import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
-import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.FileUtil
+
+import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@@ -15,9 +17,13 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
class Executor extends Logging {
- var classLoader: ClassLoader = null
+ var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
+
+ val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+ val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
+
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -36,13 +42,14 @@ class Executor extends Logging {
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
- // Create our ClassLoader (using spark properties) and set it on this thread
- classLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(classLoader)
-
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+ // Create our ClassLoader and set it on this thread
+ urlClassLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
+
}
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
@@ -54,15 +61,16 @@ class Executor extends Logging {
override def run() {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
- val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
+ val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader)
+ task.downloadDependencies(fileSet, jarSet)
+ updateClassLoader()
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -96,25 +104,16 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
- private def createClassLoader(): ClassLoader = {
- var loader = this.getClass.getClassLoader
-
- // If any JAR URIs are given through spark.jar.uris, fetch them to the
- // current directory and put them all on the classpath. We assume that
- // each URL has a unique file name so that no local filenames will clash
- // in this process. This is guaranteed by ClusterScheduler.
- val uris = System.getProperty("spark.jar.uris", "")
- val localFiles = ArrayBuffer[String]()
- for (uri <- uris.split(",").filter(_.size > 0)) {
- val url = new URL(uri)
- val filename = url.getPath.split("/").last
- downloadFile(url, filename)
- localFiles += filename
- }
- if (localFiles.size > 0) {
- val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
- loader = new URLClassLoader(urls, loader)
- }
+ private def createClassLoader(): ExecutorURLClassLoader = {
+
+ var loader = this.getClass().getClassLoader()
+
+ // For each of the jars in the jarSet, add them to the class loader.
+ // We assume each of the files has already been fetched.
+ val urls = jarSet.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@@ -133,13 +132,25 @@ class Executor extends Logging {
}
}
- return loader
+ return new ExecutorURLClassLoader(Array(), loader)
}
- // Download a file from a given URL to the local filesystem
- private def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
+ def updateClassLoader() {
+ val currentURLs = urlClassLoader.getURLs()
+ val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL }
+ urlSet.filterNot(currentURLs.contains(_)).foreach { url =>
+ logInfo("Adding " + url + " to the class loader.")
+ urlClassLoader.addURL(url)
+ }
+
}
+
+ // The addURL method in URLClassLoader is protected. We subclass it to make it accessible.
+ class ExecutorURLClassLoader(urls : Array[URL], parent : ClassLoader)
+ extends URLClassLoader(urls, parent) {
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+ }
+
}
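
URLClassLoader.addURL is protected, so the executor introduces ExecutorURLClassLoader solely to widen its visibility; updateClassLoader then appends newly fetched jars to the live loader instead of rebuilding it. The same pattern in isolation (the jar path is hypothetical):

    import java.io.File
    import java.net.{URL, URLClassLoader}

    // Subclass only to expose the protected addURL, as ExecutorURLClassLoader does.
    class MutableURLClassLoader(parent: ClassLoader)
      extends URLClassLoader(Array.empty[URL], parent) {
      override def addURL(url: URL) { super.addURL(url) }
    }

    object MutableURLClassLoaderSketch {
      def main(args: Array[String]) {
        val loader = new MutableURLClassLoader(getClass.getClassLoader)
        loader.addURL(new File("/tmp/extra-library.jar").toURI.toURL)  // hypothetical jar
        Thread.currentThread.setContextClassLoader(loader)
        println(loader.getURLs.mkString(", "))
      }
    }
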
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index a281ae94c5..b9f0a0d6d0 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -1,10 +1,10 @@
package spark.scheduler
import java.io._
-import java.util.HashMap
+import java.util.{HashMap => JHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -20,7 +20,9 @@ object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new HashMap[Int, Array[Byte]]
+ val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+ val fileSetCache = new JHashMap[Int, Array[Byte]]
+ val jarSetCache = new JHashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
@@ -40,6 +42,23 @@ object ShuffleMapTask {
}
}
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def serializeFileSet(set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = {
+ val old = cache.get(stageId)
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(set.toArray)
+ objOut.close()
+ val bytes = out.toByteArray
+ cache.put(stageId, bytes)
+ return bytes
+ }
+ }
+
+
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
@@ -54,9 +73,19 @@ object ShuffleMapTask {
}
}
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in)
+ val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+ return (HashMap(set.toSeq: _*))
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
+ fileSetCache.clear()
+ jarSetCache.clear()
}
}
}
@@ -84,6 +113,14 @@ class ShuffleMapTask(
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
+
+ val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet, stageId, ShuffleMapTask.fileSetCache)
+ out.writeInt(fileSetBytes.length)
+ out.write(fileSetBytes)
+ val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet, stageId, ShuffleMapTask.jarSetCache)
+ out.writeInt(jarSetBytes.length)
+ out.write(jarSetBytes)
+
out.writeInt(partition)
out.writeLong(generation)
out.writeObject(split)
@@ -97,6 +134,17 @@ class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
+
+ val fileSetNumBytes = in.readInt()
+ val fileSetBytes = new Array[Byte](fileSetNumBytes)
+ in.readFully(fileSetBytes)
+ fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes)
+
+ val jarSetNumBytes = in.readInt()
+ val jarSetBytes = new Array[Byte](jarSetNumBytes)
+ in.readFully(jarSetBytes)
+ jarSet = ShuffleMapTask.deserializeFileSet(jarSetBytes)
+
partition = in.readInt()
generation = in.readLong()
split = in.readObject().asInstanceOf[Split]
@@ -110,7 +158,7 @@ class ShuffleMapTask(
val bucketIterators =
if (aggregator.mapSideCombine) {
// Apply combiners (map-side aggregation) to the map output.
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+ val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(k)
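
Each ShuffleMapTask ships its file and jar sets as a gzipped, Java-serialized Array[(String, Long)], cached per stage so the master serializes them only once. The round-trip without the cache, as a sketch (the map entry is hypothetical):

    import java.io._
    import java.util.zip.{GZIPInputStream, GZIPOutputStream}
    import scala.collection.mutable.HashMap

    // The same gzip + Java-serialization round-trip ShuffleMapTask uses for its
    // file and jar sets, minus the per-stage cache.
    object FileSetCodecSketch {
      def encode(set: HashMap[String, Long]): Array[Byte] = {
        val out = new ByteArrayOutputStream
        val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
        objOut.writeObject(set.toArray)
        objOut.close()
        out.toByteArray
      }

      def decode(bytes: Array[Byte]): HashMap[String, Long] = {
        val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
        val pairs = objIn.readObject().asInstanceOf[Array[(String, Long)]]
        HashMap(pairs: _*)
      }

      def main(args: Array[String]) {
        val fileSet = HashMap("http://driver:33000/files/data.txt" -> 1347400000000L)
        assert(decode(encode(fileSet)) == fileSet)
        println("round-trip ok")
      }
    }
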
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index f84d8d9c4f..0d5b71b06c 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,5 +1,10 @@
package spark.scheduler
+import scala.collection.mutable.{HashMap}
+import spark.HttpFileServer
+import spark.Utils
+import java.io.File
+
/**
* A task to execute on a worker node.
*/
@@ -8,4 +13,30 @@ abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+
+ // Stores jar and file dependencies for this task.
+ var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
+ var jarSet : HashMap[String, Long] = new HashMap[String, Long]()
+
+ // Downloads all file dependencies from the Master file server
+ def downloadDependencies(currentFileSet : HashMap[String, Long],
+ currentJarSet : HashMap[String, Long]) {
+
+ // Fetch missing file dependencies
+ fileSet.filter { case(k,v) =>
+ !currentFileSet.contains(k) || currentFileSet(k) <= v
+ }.foreach { case (k,v) =>
+ Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+ currentFileSet(k) = v
+ }
+ // Fetch missing jar dependencies
+ jarSet.filter { case(k,v) =>
+ !currentJarSet.contains(k) || currentJarSet(k) <= v
+ }.foreach { case (k,v) =>
+ Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+ currentJarSet(k) = v
+ }
+
+ }
+
}
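
Task.downloadDependencies compares the timestamps shipped with the task against what the executor already has and re-fetches anything that is missing locally or whose local timestamp is not newer. The predicate in isolation (key and values hypothetical):

    import scala.collection.mutable.HashMap

    // The re-fetch predicate used by Task.downloadDependencies, in isolation.
    object TimestampCheckSketch {
      def needsFetch(key: String, taskTimestamp: Long, cached: HashMap[String, Long]): Boolean =
        !cached.contains(key) || cached(key) <= taskTimestamp

      def main(args: Array[String]) {
        val cached = HashMap("http://driver:33000/jars/lib.jar" -> 100L)
        println(needsFetch("http://driver:33000/jars/lib.jar", 100L, cached))  // true: cached copy is not newer
        println(needsFetch("http://driver:33000/jars/lib.jar",  50L, cached))  // false: cached copy is newer
      }
    }
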
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 5b59479682..750231ac31 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext)
def initialize(context: SchedulerBackend) {
backend = context
- createJarServer()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
@@ -88,6 +87,10 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
+ tasks.foreach { task =>
+ task.fileSet ++= sc.addedFiles
+ task.jarSet ++= sc.addedJars
+ }
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
@@ -235,32 +238,7 @@ class ClusterScheduler(sc: SparkContext)
}
override def defaultParallelism() = backend.defaultParallelism()
-
- // Create a server for all the JARs added by the user to SparkContext.
- // We first copy the JARs to a temp directory for easier server setup.
- private def createJarServer() {
- val jarDir = Utils.createTempDir()
- logInfo("Temp directory for JARs: " + jarDir)
- val filenames = ArrayBuffer[String]()
- // Copy each JAR to a unique filename in the jarDir
- for ((path, index) <- sc.jars.zipWithIndex) {
- val file = new File(path)
- if (file.exists) {
- val filename = index + "_" + file.getName
- Utils.copyFile(file, new File(jarDir, filename))
- filenames += filename
- }
- }
- // Create the server
- jarServer = new HttpServer(jarDir)
- jarServer.start()
- // Build up the jar URI list
- val serverUri = jarServer.uri
- jarUris = filenames.map(f => serverUri + "/" + f).mkString(",")
- System.setProperty("spark.jar.uris", jarUris)
- logInfo("JAR server started at " + serverUri)
- }
-
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb47988f0c..65078b026e 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,7 +1,10 @@
package spark.scheduler.local
+import java.io.File
+import java.net.URLClassLoader
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable.HashMap
import spark._
import spark.scheduler._
@@ -11,15 +14,17 @@ import spark.scheduler._
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging {
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
-
+ val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+ val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
+
// TODO: Need to take into account stage priority in scheduling
- override def start() {}
+ override def start() { }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
@@ -30,6 +35,8 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
+ task.fileSet ++= sc.addedFiles
+ task.jarSet ++= sc.addedJars
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@@ -42,6 +49,9 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
+ task.downloadDependencies(fileSet, jarSet)
+ // Create a new classLoader for the downloaded JARs
+ Thread.currentThread.setContextClassLoader(createClassLoader())
try {
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -81,9 +91,19 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
}
}
+
override def stop() {
threadPool.shutdownNow()
}
+ private def createClassLoader() : ClassLoader = {
+ val currentLoader = Thread.currentThread.getContextClassLoader()
+ val urls = jarSet.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ logInfo("Creating ClassLoader with jars: " + urls.mkString)
+ return new URLClassLoader(urls, currentLoader)
+ }
+
override def defaultParallelism() = threads
}
diff --git a/core/src/test/resources/uncommons-maths-1.2.2.jar b/core/src/test/resources/uncommons-maths-1.2.2.jar
new file mode 100644
index 0000000000..e126001c1c
--- /dev/null
+++ b/core/src/test/resources/uncommons-maths-1.2.2.jar
Binary files differ
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
new file mode 100644
index 0000000000..500af1eb90
--- /dev/null
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -0,0 +1,93 @@
+package spark
+
+import com.google.common.io.Files
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import java.io.{File, PrintWriter}
+import SparkContext._
+
+class FileServerSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+ var tmpFile : File = _
+ var testJarFile : File = _
+
+ before {
+ // Create a sample text file
+ val tmpdir = new File(Files.createTempDir(), "test")
+ tmpdir.mkdir()
+ tmpFile = new File(tmpdir, "FileServerSuite.txt")
+ val pw = new PrintWriter(tmpFile)
+ pw.println("100")
+ pw.close()
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // Clean up downloaded file
+ if (tmpFile.exists) {
+ tmpFile.delete()
+ }
+ }
+
+ test("Distributing files locally") {
+ sc = new SparkContext("local[4]", "test")
+ sc.addFile(tmpFile.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect
+ println(result)
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+ test ("Dynamically adding JARS locally") {
+ sc = new SparkContext("local[4]", "test")
+ val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
+
+ test("Distributing files on a standalone cluster") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ sc.addFile(tmpFile.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile))
+ val fileVal = in.readLine().toInt
+ in.close()
+ _ * fileVal + _ * fileVal
+ }.collect
+ println(result)
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+
+ test ("Dynamically adding JARS on a standalone cluster") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
+
+} \ No newline at end of file
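
The new FileServerSuite exercises both local and local-cluster modes; it can be run on its own with ScalaTest through sbt, e.g. sbt/sbt "test-only spark.FileServerSuite" (the exact command is assumed from the project's usual sbt workflow, not part of this commit).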