aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDenny <dennybritz@gmail.com>2012-09-04 18:52:07 -0700
committerDenny <dennybritz@gmail.com>2012-09-10 12:49:09 -0700
commitb864c36a3098e0ad8a2e508c94877bb2f4f4205d (patch)
treea36902a063b0346790ebcd8b282661166af5e02d
parentf275fb07da33cfa38fc02ed121a52caef20f61d0 (diff)
downloadspark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.tar.gz
spark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.tar.bz2
spark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.zip
Dynamically adding jar files and caching fileSets.
-rw-r--r--core/src/main/scala/spark/HttpFileServer.scala26
-rw-r--r--core/src/main/scala/spark/SparkContext.scala59
-rw-r--r--core/src/main/scala/spark/Utils.scala37
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala72
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala54
-rw-r--r--core/src/main/scala/spark/scheduler/Task.scala27
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala25
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala21
-rw-r--r--core/src/test/resources/uncommons-maths-1.2.2.jarbin0 -> 49019 bytes
-rw-r--r--core/src/test/scala/spark/FileServerSuite.scala39
10 files changed, 246 insertions, 114 deletions
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 3659de02c7..e6ad4dd28e 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -7,25 +7,39 @@ import org.apache.hadoop.fs.FileUtil
class HttpFileServer extends Logging {
+ var baseDir : File = null
var fileDir : File = null
+ var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
- fileDir = Utils.createTempDir()
- logInfo("HTTP File server directory is " + fileDir)
+ baseDir = Utils.createTempDir()
+ fileDir = new File(baseDir, "files")
+ jarDir = new File(baseDir, "jars")
+ fileDir.mkdir()
+ jarDir.mkdir()
+ logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(fileDir)
httpServer.start()
serverUri = httpServer.uri
}
+ def stop() {
+ httpServer.stop()
+ }
+
def addFile(file: File) : String = {
- Utils.copyFile(file, new File(fileDir, file.getName))
- return serverUri + "/" + file.getName
+ return addFileToDir(file, fileDir)
}
- def stop() {
- httpServer.stop()
+ def addJar(file: File) : String = {
+ return addFileToDir(file, jarDir)
+ }
+
+ def addFileToDir(file: File, dir: File) : String = {
+ Utils.copyFile(file, new File(dir, file.getName))
+ return dir + "/" + file.getName
}
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index dee7cd4925..7a1bf692e4 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,14 +2,14 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
-import java.net.URI
+import java.net.{URI, URLClassLoader}
import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -78,8 +78,12 @@ class SparkContext(
isLocal)
SparkEnv.set(env)
- // Used to store a URL for each static file together with the file's local timestamp
- val files = HashMap[String, Long]()
+ // Used to store a URL for each static file/jar together with the file's local timestamp
+ val addedFiles = HashMap[String, Long]()
+ val addedJars = HashMap[String, Long]()
+
+ // Add each JAR given through the constructor
+ jars.foreach { addJar(_) }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@@ -316,20 +320,40 @@ class SparkContext(
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Adds a file dependency to all Tasks executed in the future.
- def addFile(path: String) : String = {
+ def addFile(path: String) {
val uri = new URI(path)
- uri.getScheme match {
- // A local file
- case null | "file" =>
- val file = new File(uri.getPath)
- val url = env.httpFileServer.addFile(file)
- files(url) = System.currentTimeMillis
- logInfo("Added file " + path + " at " + url + " with timestamp " + files(url))
- return url
- case _ =>
- files(path) = System.currentTimeMillis
- return path
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
+ case _ => path
+ }
+ addedFiles(key) = System.currentTimeMillis
+
+ // Fetch the file locally in case the task is executed locally
+ val filename = new File(path.split("/").last)
+ Utils.fetchFile(path, new File(""))
+
+ logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+ }
+
+ def clearFiles() {
+ addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedFiles.clear()
+ }
+
+ // Adds a jar dependency to all Tasks executed in the future.
+ def addJar(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
}
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+ }
+
+ def clearJars() {
+ addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+ addedJars.clear()
}
// Stop the SparkContext
@@ -339,6 +363,9 @@ class SparkContext(
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext")
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index eb0a4c99bb..07aa18e540 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -5,7 +5,7 @@ import java.net.{InetAddress, URL, URI}
import java.util.{Locale, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import scala.io.Source
@@ -133,20 +133,27 @@ object Utils extends Logging {
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
- if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
- // Use the java.net library to fetch it
- logInfo("Fetching " + url + " to " + targetFile)
- val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
- } else {
- // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val uri = new URI(url)
- val conf = new Configuration()
- val fs = FileSystem.get(uri, conf)
- val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
- Utils.copyStream(in, out, true)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + targetFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
+ case "file" | null =>
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally
+ logInfo("Symlinking " + url + " to " + targetFile)
+ FileUtil.symLink(url, targetFile.toString)
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val uri = new URI(url)
+ val conf = new Configuration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(targetFile)
+ Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index ce3aa49726..2d53c7a6ad 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -1,12 +1,12 @@
package spark.executor
import java.io.{File, FileOutputStream}
-import java.net.{URL, URLClassLoader}
+import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
import org.apache.hadoop.fs.FileUtil
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@@ -17,11 +17,13 @@ import java.nio.ByteBuffer
* The Mesos executor for Spark.
*/
class Executor extends Logging {
- var classLoader: ClassLoader = null
+ var urlClassLoader : URLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+ val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
+
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -40,13 +42,14 @@ class Executor extends Logging {
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
- // Create our ClassLoader (using spark properties) and set it on this thread
- classLoader = createClassLoader()
- Thread.currentThread.setContextClassLoader(classLoader)
-
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+ // Create our ClassLoader and set it on this thread
+ urlClassLoader = createClassLoader()
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
+
}
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
@@ -58,16 +61,16 @@ class Executor extends Logging {
override def run() {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
+ Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
SparkEnv.set(env)
- Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
- val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
- task.downloadFileDependencies(fileSet)
+ val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader)
+ task.downloadDependencies(fileSet, jarSet)
+ updateClassLoader()
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -101,25 +104,16 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
- private def createClassLoader(): ClassLoader = {
- var loader = this.getClass.getClassLoader
-
- // If any JAR URIs are given through spark.jar.uris, fetch them to the
- // current directory and put them all on the classpath. We assume that
- // each URL has a unique file name so that no local filenames will clash
- // in this process. This is guaranteed by ClusterScheduler.
- val uris = System.getProperty("spark.jar.uris", "")
- val localFiles = ArrayBuffer[String]()
- for (uri <- uris.split(",").filter(_.size > 0)) {
- val url = new URL(uri)
- val filename = url.getPath.split("/").last
- Utils.downloadFile(url, filename)
- localFiles += filename
- }
- if (localFiles.size > 0) {
- val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
- loader = new URLClassLoader(urls, loader)
- }
+ private def createClassLoader(): URLClassLoader = {
+
+ var loader = this.getClass().getClassLoader()
+
+ // For each of the jars in the jarSet, add them to the class loader.
+ // We assume each of the files has already been fetched.
+ val urls = jarSet.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@@ -138,7 +132,23 @@ class Executor extends Logging {
}
}
- return loader
+ return new URLClassLoader(Array(), loader)
+ }
+
+ def updateClassLoader() {
+ val currentURLs = urlClassLoader.getURLs()
+
+ val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL }
+
+ // For abstraction reasons the addURL method in URLClassLoader is protected.
+ // We'll save us the hassle of sublassing here and use relfection instead.
+ val m = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL])
+ m.setAccessible(true)
+ urlSet.filterNot(currentURLs.contains(_)).foreach { url =>
+ logInfo("Adding " + url + " to the class loader.")
+ m.invoke(urlClassLoader, url)
+ }
+
}
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index a281ae94c5..3687bb990c 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -1,10 +1,10 @@
package spark.scheduler
import java.io._
-import java.util.HashMap
+import java.util.{HashMap => JHashMap}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -20,7 +20,8 @@ object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new HashMap[Int, Array[Byte]]
+ val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+ val fileSetCache = new JHashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
@@ -40,6 +41,23 @@ object ShuffleMapTask {
}
}
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def serializeFileSet(set : HashMap[String, Long]) : Array[Byte] = {
+ val old = fileSetCache.get(set.hashCode)
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ objOut.writeObject(set.toArray)
+ objOut.close()
+ val bytes = out.toByteArray
+ fileSetCache.put(set.hashCode, bytes)
+ return bytes
+ }
+ }
+
+
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
@@ -54,9 +72,18 @@ object ShuffleMapTask {
}
}
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in)
+ val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+ return (HashMap(set.toSeq: _*))
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
+ fileSetCache.clear()
}
}
}
@@ -84,6 +111,14 @@ class ShuffleMapTask(
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
+
+ val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet)
+ out.writeInt(fileSetBytes.length)
+ out.write(fileSetBytes)
+ val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet)
+ out.writeInt(jarSetBytes.length)
+ out.write(jarSetBytes)
+
out.writeInt(partition)
out.writeLong(generation)
out.writeObject(split)
@@ -97,6 +132,17 @@ class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
+
+ val fileSetNumBytes = in.readInt()
+ val fileSetBytes = new Array[Byte](fileSetNumBytes)
+ in.readFully(fileSetBytes)
+ fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes)
+
+ val jarSetNumBytes = in.readInt()
+ val jarSetBytes = new Array[Byte](jarSetNumBytes)
+ in.readFully(jarSetBytes)
+ fileSet = ShuffleMapTask.deserializeFileSet(jarSetBytes)
+
partition = in.readInt()
generation = in.readLong()
split = in.readObject().asInstanceOf[Split]
@@ -110,7 +156,7 @@ class ShuffleMapTask(
val bucketIterators =
if (aggregator.mapSideCombine) {
// Apply combiners (map-side aggregation) to the map output.
- val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+ val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
for (elem <- rdd.iterator(split)) {
val (k, v) = elem.asInstanceOf[(Any, Any)]
val bucketId = partitioner.getPartition(k)
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index faf042ad02..0d5b71b06c 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,6 +1,6 @@
package spark.scheduler
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap}
import spark.HttpFileServer
import spark.Utils
import java.io.File
@@ -14,20 +14,29 @@ abstract class Task[T](val stageId: Int) extends Serializable {
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
- // Stores file dependencies for this task.
+ // Stores jar and file dependencies for this task.
var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
+ var jarSet : HashMap[String, Long] = new HashMap[String, Long]()
// Downloads all file dependencies from the Master file server
- def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
- // Find files that either don't exist or have an earlier timestamp
- val missingFiles = fileSet.filter { case(k,v) =>
- !currentFileSet.isDefinedAt(k) || currentFileSet(k) <= v
- }
- // Fetch each missing file
- missingFiles.foreach { case (k,v) =>
+ def downloadDependencies(currentFileSet : HashMap[String, Long],
+ currentJarSet : HashMap[String, Long]) {
+
+ // Fetch missing file dependencies
+ fileSet.filter { case(k,v) =>
+ !currentFileSet.contains(k) || currentFileSet(k) <= v
+ }.foreach { case (k,v) =>
Utils.fetchFile(k, new File(System.getProperty("user.dir")))
currentFileSet(k) = v
}
+ // Fetch missing jar dependencies
+ jarSet.filter { case(k,v) =>
+ !currentJarSet.contains(k) || currentJarSet(k) <= v
+ }.foreach { case (k,v) =>
+ Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+ currentJarSet(k) = v
+ }
+
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index a9ab82040c..750231ac31 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext)
def initialize(context: SchedulerBackend) {
backend = context
- createJarServer()
}
def newTaskId(): Long = nextTaskId.getAndIncrement()
@@ -88,7 +87,10 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
- tasks.foreach { task => task.fileSet ++= sc.files }
+ tasks.foreach { task =>
+ task.fileSet ++= sc.addedFiles
+ task.jarSet ++= sc.addedJars
+ }
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
@@ -237,25 +239,6 @@ class ClusterScheduler(sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
- // Copies all the JARs added by the user to the SparkContext
- // to the fileserver directory.
- private def createJarServer() {
- val fileServerDir = SparkEnv.get.httpFileServer.fileDir
- val fileServerUri = SparkEnv.get.httpFileServer.serverUri
- val filenames = ArrayBuffer[String]()
- for ((path, index) <- sc.jars.zipWithIndex) {
- val file = new File(path)
- if (file.exists) {
- val filename = index + "_" + file.getName
- Utils.copyFile(file, new File(fileServerDir, filename))
- filenames += filename
- }
- }
- jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
- System.setProperty("spark.jar.uris", jarUris)
- logInfo("JARs available at " + jarUris)
- }
-
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 4bd9d13637..65078b026e 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,5 +1,7 @@
package spark.scheduler.local
+import java.io.File
+import java.net.URLClassLoader
import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
@@ -18,10 +20,11 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+ val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
// TODO: Need to take into account stage priority in scheduling
- override def start() {}
+ override def start() { }
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
@@ -32,7 +35,8 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
- task.fileSet ++= sc.files
+ task.fileSet ++= sc.addedFiles
+ task.jarSet ++= sc.addedJars
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@@ -45,7 +49,9 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
- task.downloadFileDependencies(fileSet)
+ task.downloadDependencies(fileSet, jarSet)
+ // Create a new classLaoder for the downloaded JARs
+ Thread.currentThread.setContextClassLoader(createClassLoader())
try {
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -90,5 +96,14 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
threadPool.shutdownNow()
}
+ private def createClassLoader() : ClassLoader = {
+ val currentLoader = Thread.currentThread.getContextClassLoader()
+ val urls = jarSet.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ logInfo("Creating ClassLoader with jars: " + urls.mkString)
+ return new URLClassLoader(urls, currentLoader)
+ }
+
override def defaultParallelism() = threads
}
diff --git a/core/src/test/resources/uncommons-maths-1.2.2.jar b/core/src/test/resources/uncommons-maths-1.2.2.jar
new file mode 100644
index 0000000000..e126001c1c
--- /dev/null
+++ b/core/src/test/resources/uncommons-maths-1.2.2.jar
Binary files differ
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index 883149feca..05517e8be4 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -1,16 +1,23 @@
package spark
+import com.google.common.io.Files
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import java.io.{File, PrintWriter}
+import SparkContext._
class FileServerSuite extends FunSuite with BeforeAndAfter {
var sc: SparkContext = _
+ var tmpFile : File = _
+ var testJarFile : File = _
before {
// Create a sample text file
- val pw = new PrintWriter(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
+ val tmpdir = new File(Files.createTempDir(), "test")
+ tmpdir.mkdir()
+ tmpFile = new File(tmpdir, "FileServerSuite.txt")
+ val pw = new PrintWriter(tmpFile)
pw.println("100")
pw.close()
}
@@ -21,7 +28,6 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc = null
}
// Clean up downloaded file
- val tmpFile = new File("FileServerSuite.txt")
if (tmpFile.exists) {
tmpFile.delete()
}
@@ -29,15 +35,30 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
test("Distributing files") {
sc = new SparkContext("local[4]", "test")
- sc.addFile(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
- val testRdd = sc.parallelize(List(1,2,3,4))
- val result = testRdd.map { x =>
- val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
+ sc.addFile(tmpFile.toString)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+ val result = sc.parallelize(testData).reduceByKey {
+ val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile))
val fileVal = in.readLine().toInt
in.close()
- fileVal
- }.reduce(_ + _)
- assert(result == 400)
+ _ * fileVal + _ * fileVal
+ }.collect
+ println(result)
+ assert(result.toSet === Set((1,200), (2,300), (3,500)))
+ }
+
+ test ("Dynamically adding JARS") {
+ sc = new SparkContext("local[4]", "test")
+ val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile)
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
} \ No newline at end of file