author | Denny <dennybritz@gmail.com> | 2012-09-04 18:52:07 -0700
---|---|---
committer | Denny <dennybritz@gmail.com> | 2012-09-10 12:49:09 -0700
commit | b864c36a3098e0ad8a2e508c94877bb2f4f4205d (patch) |
tree | a36902a063b0346790ebcd8b282661166af5e02d /core/src |
parent | f275fb07da33cfa38fc02ed121a52caef20f61d0 (diff) |
download | spark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.tar.gz spark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.tar.bz2 spark-b864c36a3098e0ad8a2e508c94877bb2f4f4205d.zip |
Dynamically adding jar files and caching fileSets.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/spark/HttpFileServer.scala | 26
-rw-r--r-- | core/src/main/scala/spark/SparkContext.scala | 59
-rw-r--r-- | core/src/main/scala/spark/Utils.scala | 37
-rw-r--r-- | core/src/main/scala/spark/executor/Executor.scala | 72
-rw-r--r-- | core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 54
-rw-r--r-- | core/src/main/scala/spark/scheduler/Task.scala | 27
-rw-r--r-- | core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 25
-rw-r--r-- | core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 21
-rw-r--r-- | core/src/test/resources/uncommons-maths-1.2.2.jar | bin 0 -> 49019 bytes
-rw-r--r-- | core/src/test/scala/spark/FileServerSuite.scala | 39
10 files changed, 246 insertions, 114 deletions
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 3659de02c7..e6ad4dd28e 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -7,25 +7,39 @@ import org.apache.hadoop.fs.FileUtil
 class HttpFileServer extends Logging {
 
+  var baseDir : File = null
   var fileDir : File = null
+  var jarDir : File = null
   var httpServer : HttpServer = null
   var serverUri : String = null
 
   def initialize() {
-    fileDir = Utils.createTempDir()
-    logInfo("HTTP File server directory is " + fileDir)
+    baseDir = Utils.createTempDir()
+    fileDir = new File(baseDir, "files")
+    jarDir = new File(baseDir, "jars")
+    fileDir.mkdir()
+    jarDir.mkdir()
+    logInfo("HTTP File server directory is " + baseDir)
     httpServer = new HttpServer(fileDir)
     httpServer.start()
     serverUri = httpServer.uri
   }
 
+  def stop() {
+    httpServer.stop()
+  }
+
   def addFile(file: File) : String = {
-    Utils.copyFile(file, new File(fileDir, file.getName))
-    return serverUri + "/" + file.getName
+    return addFileToDir(file, fileDir)
   }
 
-  def stop() {
-    httpServer.stop()
+  def addJar(file: File) : String = {
+    return addFileToDir(file, jarDir)
+  }
+
+  def addFileToDir(file: File, dir: File) : String = {
+    Utils.copyFile(file, new File(dir, file.getName))
+    return dir + "/" + file.getName
   }
 
 }
\ No newline at end of file
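The new layout splits the served directory into `files/` and `jars/`. Roughly, the driver-side wiring now looks like this (a minimal sketch; the two paths below are placeholders, not part of the patch):

```scala
import java.io.File

val server = new spark.HttpFileServer
server.initialize()                                            // creates <tmp>/files and <tmp>/jars
val fileKey = server.addFile(new File("/tmp/lookup.txt"))      // copied into the "files" subdirectory
val jarKey  = server.addJar(new File("/tmp/extra-algos.jar"))  // copied into the "jars" subdirectory
// The returned strings are the keys that SparkContext records in addedFiles/addedJars.
server.stop()
```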
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index dee7cd4925..7a1bf692e4 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -2,14 +2,14 @@ package spark
 
 import java.io._
 import java.util.concurrent.atomic.AtomicInteger
-import java.net.URI
+import java.net.{URI, URLClassLoader}
 
 import akka.actor.Actor
 import akka.actor.Actor._
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileUtil, Path}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -78,8 +78,12 @@ class SparkContext(
     isLocal)
   SparkEnv.set(env)
 
-  // Used to store a URL for each static file together with the file's local timestamp
-  val files = HashMap[String, Long]()
+  // Used to store a URL for each static file/jar together with the file's local timestamp
+  val addedFiles = HashMap[String, Long]()
+  val addedJars = HashMap[String, Long]()
+
+  // Add each JAR given through the constructor
+  jars.foreach { addJar(_) }
 
   // Create and start the scheduler
   private var taskScheduler: TaskScheduler = {
@@ -316,20 +320,40 @@ class SparkContext(
   def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
 
   // Adds a file dependency to all Tasks executed in the future.
-  def addFile(path: String) : String = {
+  def addFile(path: String) {
     val uri = new URI(path)
-    uri.getScheme match {
-      // A local file
-      case null | "file" =>
-        val file = new File(uri.getPath)
-        val url = env.httpFileServer.addFile(file)
-        files(url) = System.currentTimeMillis
-        logInfo("Added file " + path + " at " + url + " with timestamp " + files(url))
-        return url
-      case _ =>
-        files(path) = System.currentTimeMillis
-        return path
+    val key = uri.getScheme match {
+      case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
+      case _ => path
+    }
+    addedFiles(key) = System.currentTimeMillis
+
+    // Fetch the file locally in case the task is executed locally
+    val filename = new File(path.split("/").last)
+    Utils.fetchFile(path, new File(""))
+
+    logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+  }
+
+  def clearFiles() {
+    addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+    addedFiles.clear()
+  }
+
+  // Adds a jar dependency to all Tasks executed in the future.
+  def addJar(path: String) {
+    val uri = new URI(path)
+    val key = uri.getScheme match {
+      case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
+      case _ => path
     }
+    addedJars(key) = System.currentTimeMillis
+    logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+  }
+
+  def clearJars() {
+    addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
+    addedJars.clear()
   }
 
   // Stop the SparkContext
@@ -339,6 +363,9 @@ class SparkContext(
     taskScheduler = null
     // TODO: Cache.stop()?
     env.stop()
+    // Clean up locally linked files
+    clearFiles()
+    clearJars()
     SparkEnv.set(null)
     ShuffleMapTask.clearCache()
     logInfo("Successfully stopped SparkContext")
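To make the intent of the new SparkContext API concrete, here is a small driver-side sketch (the application name, file path, and jar path are invented for illustration):

```scala
import spark.SparkContext

val sc = new SparkContext("local[4]", "deps-demo")

// Ship a data file and a jar; tasks pick them up lazily through downloadDependencies.
sc.addFile("/tmp/lookup.txt")
sc.addJar("/tmp/extra-algos.jar")

val result = sc.parallelize(1 to 4).map { i =>
  // Fetched files land in the task's working directory under their original name.
  val in = new java.io.BufferedReader(new java.io.FileReader("lookup.txt"))
  val v = in.readLine().toInt
  in.close()
  i * v
}.collect()

sc.stop()   // stop() now also calls clearFiles() and clearJars()
```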
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index eb0a4c99bb..07aa18e540 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -5,7 +5,7 @@ import java.net.{InetAddress, URL, URI}
 import java.util.{Locale, UUID}
 import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 import scala.io.Source
@@ -133,20 +133,27 @@ object Utils extends Logging {
   def fetchFile(url: String, targetDir: File) {
     val filename = url.split("/").last
     val targetFile = new File(targetDir, filename)
-    if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
-      // Use the java.net library to fetch it
-      logInfo("Fetching " + url + " to " + targetFile)
-      val in = new URL(url).openStream()
-      val out = new FileOutputStream(targetFile)
-      Utils.copyStream(in, out, true)
-    } else {
-      // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
-      val uri = new URI(url)
-      val conf = new Configuration()
-      val fs = FileSystem.get(uri, conf)
-      val in = fs.open(new Path(uri))
-      val out = new FileOutputStream(targetFile)
-      Utils.copyStream(in, out, true)
+    val uri = new URI(url)
+    uri.getScheme match {
+      case "http" | "https" | "ftp" =>
+        logInfo("Fetching " + url + " to " + targetFile)
+        val in = new URL(url).openStream()
+        val out = new FileOutputStream(targetFile)
+        Utils.copyStream(in, out, true)
+      case "file" | null =>
+        // Remove the file if it already exists
+        targetFile.delete()
+        // Symlink the file locally
+        logInfo("Symlinking " + url + " to " + targetFile)
+        FileUtil.symLink(url, targetFile.toString)
+      case _ =>
+        // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+        val uri = new URI(url)
+        val conf = new Configuration()
+        val fs = FileSystem.get(uri, conf)
+        val in = fs.open(new Path(uri))
+        val out = new FileOutputStream(targetFile)
+        Utils.copyStream(in, out, true)
     }
     // Decompress the file if it's a .tar or .tar.gz
     if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
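From a caller's point of view, the new scheme dispatch in `Utils.fetchFile` behaves as follows (the URLs are placeholders; this assumes `spark.Utils` is on the classpath):

```scala
import java.io.File

val workDir = new File(System.getProperty("user.dir"))

spark.Utils.fetchFile("http://master:12345/files/lookup.txt", workDir)   // java.net download
spark.Utils.fetchFile("file:///tmp/lookup.txt", workDir)                 // symlinked via FileUtil.symLink
spark.Utils.fetchFile("hdfs://namenode:9000/deps/lookup.txt", workDir)   // Hadoop FileSystem API
```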
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index ce3aa49726..2d53c7a6ad 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -1,12 +1,12 @@
 package spark.executor
 
 import java.io.{File, FileOutputStream}
-import java.net.{URL, URLClassLoader}
+import java.net.{URI, URL, URLClassLoader}
 import java.util.concurrent._
 
 import org.apache.hadoop.fs.FileUtil
 
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
 
 import spark.broadcast._
 import spark.scheduler._
@@ -17,11 +17,13 @@ import java.nio.ByteBuffer
  * The Mesos executor for Spark.
  */
 class Executor extends Logging {
-  var classLoader: ClassLoader = null
+  var urlClassLoader : URLClassLoader = null
   var threadPool: ExecutorService = null
   var env: SparkEnv = null
 
   val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+  val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
+
   val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
 
@@ -40,13 +42,14 @@ class Executor extends Logging {
     env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
     SparkEnv.set(env)
 
-    // Create our ClassLoader (using spark properties) and set it on this thread
-    classLoader = createClassLoader()
-    Thread.currentThread.setContextClassLoader(classLoader)
-
     // Start worker thread pool
     threadPool = new ThreadPoolExecutor(
       1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+    // Create our ClassLoader and set it on this thread
+    urlClassLoader = createClassLoader()
+    Thread.currentThread.setContextClassLoader(urlClassLoader)
+
   }
 
   def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
@@ -58,16 +61,16 @@ class Executor extends Logging {
     override def run() {
       SparkEnv.set(env)
-      Thread.currentThread.setContextClassLoader(classLoader)
+      Thread.currentThread.setContextClassLoader(urlClassLoader)
       val ser = SparkEnv.get.closureSerializer.newInstance()
       logInfo("Running task ID " + taskId)
       context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
       try {
         SparkEnv.set(env)
-        Thread.currentThread.setContextClassLoader(classLoader)
         Accumulators.clear()
-        val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
-        task.downloadFileDependencies(fileSet)
+        val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader)
+        task.downloadDependencies(fileSet, jarSet)
+        updateClassLoader()
         logInfo("Its generation is " + task.generation)
         env.mapOutputTracker.updateGeneration(task.generation)
         val value = task.run(taskId.toInt)
@@ -101,25 +104,16 @@ class Executor extends Logging {
    * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
    */
-  private def createClassLoader(): ClassLoader = {
-    var loader = this.getClass.getClassLoader
-
-    // If any JAR URIs are given through spark.jar.uris, fetch them to the
-    // current directory and put them all on the classpath. We assume that
-    // each URL has a unique file name so that no local filenames will clash
-    // in this process. This is guaranteed by ClusterScheduler.
-    val uris = System.getProperty("spark.jar.uris", "")
-    val localFiles = ArrayBuffer[String]()
-    for (uri <- uris.split(",").filter(_.size > 0)) {
-      val url = new URL(uri)
-      val filename = url.getPath.split("/").last
-      Utils.downloadFile(url, filename)
-      localFiles += filename
-    }
-    if (localFiles.size > 0) {
-      val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
-      loader = new URLClassLoader(urls, loader)
-    }
+  private def createClassLoader(): URLClassLoader = {
+
+    var loader = this.getClass().getClassLoader()
+
+    // For each of the jars in the jarSet, add them to the class loader.
+    // We assume each of the files has already been fetched.
+    val urls = jarSet.keySet.map { uri =>
+      new File(uri.split("/").last).toURI.toURL
+    }.toArray
+    loader = new URLClassLoader(urls, loader)
 
     // If the REPL is in use, add another ClassLoader that will read
     // new classes defined by the REPL as the user types code
@@ -138,7 +132,23 @@ class Executor extends Logging {
       }
     }
 
-    return loader
+    return new URLClassLoader(Array(), loader)
+  }
+
+  def updateClassLoader() {
+    val currentURLs = urlClassLoader.getURLs()
+
+    val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL }
+
+    // For abstraction reasons the addURL method in URLClassLoader is protected.
+    // We'll save us the hassle of sublassing here and use relfection instead.
+    val m = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL])
+    m.setAccessible(true)
+    urlSet.filterNot(currentURLs.contains(_)).foreach { url =>
+      logInfo("Adding " + url + " to the class loader.")
+      m.invoke(urlClassLoader, url)
+    }
+  }
 }
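The key trick in `updateClassLoader()` is reaching the protected `URLClassLoader.addURL` through reflection instead of subclassing. Isolated from the executor, the same idea looks like this (the jar name is a placeholder):

```scala
import java.io.File
import java.net.{URL, URLClassLoader}

// Start from an empty class loader, as createClassLoader() now does, so jars can be appended later.
val loader = new URLClassLoader(Array[URL](), getClass.getClassLoader)

// addURL is protected, so grab it reflectively and make it accessible.
val addURL = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL])
addURL.setAccessible(true)

// Append a freshly fetched jar without replacing the loader already installed on the thread.
addURL.invoke(loader, new File("extra-algos.jar").toURI.toURL)
```

Appending to the one executor-wide loader, rather than building a new loader per task, keeps classes loaded by earlier tasks visible to later ones.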
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index a281ae94c5..3687bb990c 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -1,10 +1,10 @@
 package spark.scheduler
 
 import java.io._
-import java.util.HashMap
+import java.util.{HashMap => JHashMap}
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.collection.JavaConversions._
 
 import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -20,7 +20,8 @@ object ShuffleMapTask {
   // A simple map between the stage id to the serialized byte array of a task.
   // Served as a cache for task serialization because serialization can be
   // expensive on the master node if it needs to launch thousands of tasks.
-  val serializedInfoCache = new HashMap[Int, Array[Byte]]
+  val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+  val fileSetCache = new JHashMap[Int, Array[Byte]]
 
   def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
     synchronized {
@@ -40,6 +41,23 @@ object ShuffleMapTask {
     }
   }
 
+  // Since both the JarSet and FileSet have the same format this is used for both.
+  def serializeFileSet(set : HashMap[String, Long]) : Array[Byte] = {
+    val old = fileSetCache.get(set.hashCode)
+    if (old != null) {
+      return old
+    } else {
+      val out = new ByteArrayOutputStream
+      val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+      objOut.writeObject(set.toArray)
+      objOut.close()
+      val bytes = out.toByteArray
+      fileSetCache.put(set.hashCode, bytes)
+      return bytes
+    }
+  }
+
+
   def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
     synchronized {
       val loader = Thread.currentThread.getContextClassLoader
@@ -54,9 +72,18 @@ object ShuffleMapTask {
     }
   }
 
+  // Since both the JarSet and FileSet have the same format this is used for both.
+  def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val objIn = new ObjectInputStream(in)
+    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+    return (HashMap(set.toSeq: _*))
+  }
+
   def clearCache() {
     synchronized {
       serializedInfoCache.clear()
+      fileSetCache.clear()
     }
   }
 }
@@ -84,6 +111,14 @@ class ShuffleMapTask(
     val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
     out.writeInt(bytes.length)
     out.write(bytes)
+
+    val fileSetBytes = ShuffleMapTask.serializeFileSet(fileSet)
+    out.writeInt(fileSetBytes.length)
+    out.write(fileSetBytes)
+    val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet)
+    out.writeInt(jarSetBytes.length)
+    out.write(jarSetBytes)
+
     out.writeInt(partition)
     out.writeLong(generation)
     out.writeObject(split)
@@ -97,6 +132,17 @@ class ShuffleMapTask(
     val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
     rdd = rdd_
     dep = dep_
+
+    val fileSetNumBytes = in.readInt()
+    val fileSetBytes = new Array[Byte](fileSetNumBytes)
+    in.readFully(fileSetBytes)
+    fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes)
+
+    val jarSetNumBytes = in.readInt()
+    val jarSetBytes = new Array[Byte](jarSetNumBytes)
+    in.readFully(jarSetBytes)
+    fileSet = ShuffleMapTask.deserializeFileSet(jarSetBytes)
+
     partition = in.readInt()
     generation = in.readLong()
     split = in.readObject().asInstanceOf[Split]
@@ -110,7 +156,7 @@ class ShuffleMapTask(
     val bucketIterators =
       if (aggregator.mapSideCombine) {
         // Apply combiners (map-side aggregation) to the map output.
-        val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+        val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
         for (elem <- rdd.iterator(split)) {
           val (k, v) = elem.asInstanceOf[(Any, Any)]
           val bucketId = partitioner.getPartition(k)
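For reference, the round trip performed by `serializeFileSet`/`deserializeFileSet` above boils down to the following (the map entry is invented for illustration):

```scala
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap

val deps = HashMap("http://master:12345/files/lookup.txt" -> 1346800000000L)

// Serialize: the map is written as an Array[(String, Long)] through a gzipped object stream.
val bytesOut = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(bytesOut))
objOut.writeObject(deps.toArray)
objOut.close()
val bytes = bytesOut.toByteArray

// Deserialize: read the array back and rebuild a mutable HashMap from it.
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
val restored = HashMap(objIn.readObject().asInstanceOf[Array[(String, Long)]]: _*)
assert(restored == deps)
```

The `fileSetCache` keys on `set.hashCode`, so an identical dependency set shared by all tasks of a stage is only serialized once.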
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index faf042ad02..0d5b71b06c 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,6 +1,6 @@
 package spark.scheduler
 
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap}
 import spark.HttpFileServer
 import spark.Utils
 import java.io.File
@@ -14,20 +14,29 @@ abstract class Task[T](val stageId: Int) extends Serializable {
 
   var generation: Long = -1   // Map output tracker generation. Will be set by TaskScheduler.
 
-  // Stores file dependencies for this task.
+  // Stores jar and file dependencies for this task.
   var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
+  var jarSet : HashMap[String, Long] = new HashMap[String, Long]()
 
   // Downloads all file dependencies from the Master file server
-  def downloadFileDependencies(currentFileSet : HashMap[String, Long]) {
-    // Find files that either don't exist or have an earlier timestamp
-    val missingFiles = fileSet.filter { case(k,v) =>
-      !currentFileSet.isDefinedAt(k) || currentFileSet(k) <= v
-    }
-    // Fetch each missing file
-    missingFiles.foreach { case (k,v) =>
+  def downloadDependencies(currentFileSet : HashMap[String, Long],
+    currentJarSet : HashMap[String, Long]) {
+
+    // Fetch missing file dependencies
+    fileSet.filter { case(k,v) =>
+      !currentFileSet.contains(k) || currentFileSet(k) <= v
+    }.foreach { case (k,v) =>
       Utils.fetchFile(k, new File(System.getProperty("user.dir")))
       currentFileSet(k) = v
     }
+    // Fetch missing jar dependencies
+    jarSet.filter { case(k,v) =>
+      !currentJarSet.contains(k) || currentJarSet(k) <= v
+    }.foreach { case (k,v) =>
+      Utils.fetchFile(k, new File(System.getProperty("user.dir")))
+      currentJarSet(k) = v
+    }
+
+  }
 }
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index a9ab82040c..750231ac31 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext)
 
   def initialize(context: SchedulerBackend) {
     backend = context
-    createJarServer()
   }
 
   def newTaskId(): Long = nextTaskId.getAndIncrement()
@@ -88,7 +87,10 @@ class ClusterScheduler(sc: SparkContext)
 
   def submitTasks(taskSet: TaskSet) {
     val tasks = taskSet.tasks
-    tasks.foreach { task => task.fileSet ++= sc.files }
+    tasks.foreach { task =>
+      task.fileSet ++= sc.addedFiles
+      task.jarSet ++= sc.addedJars
+    }
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = new TaskSetManager(this, taskSet)
@@ -237,25 +239,6 @@ class ClusterScheduler(sc: SparkContext)
 
   override def defaultParallelism() = backend.defaultParallelism()
 
-  // Copies all the JARs added by the user to the SparkContext
-  // to the fileserver directory.
-  private def createJarServer() {
-    val fileServerDir = SparkEnv.get.httpFileServer.fileDir
-    val fileServerUri = SparkEnv.get.httpFileServer.serverUri
-    val filenames = ArrayBuffer[String]()
-    for ((path, index) <- sc.jars.zipWithIndex) {
-      val file = new File(path)
-      if (file.exists) {
-        val filename = index + "_" + file.getName
-        Utils.copyFile(file, new File(fileServerDir, filename))
-        filenames += filename
-      }
-    }
-    jarUris = filenames.map(f => fileServerUri + "/" + f).mkString(",")
-    System.setProperty("spark.jar.uris", jarUris)
-    logInfo("JARs available at " + jarUris)
-  }
-
   // Check for speculatable tasks in all our active jobs.
   def checkSpeculatableTasks() {
     var shouldRevive = false
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 4bd9d13637..65078b026e 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,5 +1,7 @@
 package spark.scheduler.local
 
+import java.io.File
+import java.net.URLClassLoader
 import java.util.concurrent.Executors
 import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.mutable.HashMap
@@ -18,10 +20,11 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
   val env = SparkEnv.get
   var listener: TaskSchedulerListener = null
   val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
+  val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
 
   // TODO: Need to take into account stage priority in scheduling
 
-  override def start() {}
+  override def start() { }
 
   override def setListener(listener: TaskSchedulerListener) {
     this.listener = listener
@@ -32,7 +35,8 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
     val failCount = new Array[Int](tasks.size)
 
     def submitTask(task: Task[_], idInJob: Int) {
-      task.fileSet ++= sc.files
+      task.fileSet ++= sc.addedFiles
+      task.jarSet ++= sc.addedJars
       val myAttemptId = attemptId.getAndIncrement()
       threadPool.submit(new Runnable {
         def run() {
@@ -45,7 +49,9 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
       logInfo("Running task " + idInJob)
       // Set the Spark execution environment for the worker thread
      SparkEnv.set(env)
-      task.downloadFileDependencies(fileSet)
+      task.downloadDependencies(fileSet, jarSet)
+      // Create a new classLaoder for the downloaded JARs
+      Thread.currentThread.setContextClassLoader(createClassLoader())
       try {
         // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
         // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
@@ -90,5 +96,14 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
     threadPool.shutdownNow()
   }
 
+  private def createClassLoader() : ClassLoader = {
+    val currentLoader = Thread.currentThread.getContextClassLoader()
+    val urls = jarSet.keySet.map { uri =>
+      new File(uri.split("/").last).toURI.toURL
+    }.toArray
+    logInfo("Creating ClassLoader with jars: " + urls.mkString)
+    return new URLClassLoader(urls, currentLoader)
+  }
+
   override def defaultParallelism() = threads
 }
diff --git a/core/src/test/resources/uncommons-maths-1.2.2.jar b/core/src/test/resources/uncommons-maths-1.2.2.jar
new file mode 100644
index 0000000000..e126001c1c
--- /dev/null
+++ b/core/src/test/resources/uncommons-maths-1.2.2.jar
Binary files differ
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index 883149feca..05517e8be4 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -1,16 +1,23 @@
 package spark
 
+import com.google.common.io.Files
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
 import java.io.{File, PrintWriter}
+import SparkContext._
 
 class FileServerSuite extends FunSuite with BeforeAndAfter {
 
   var sc: SparkContext = _
+  var tmpFile : File = _
+  var testJarFile : File = _
 
   before {
     // Create a sample text file
-    val pw = new PrintWriter(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
+    val tmpdir = new File(Files.createTempDir(), "test")
+    tmpdir.mkdir()
+    tmpFile = new File(tmpdir, "FileServerSuite.txt")
+    val pw = new PrintWriter(tmpFile)
     pw.println("100")
     pw.close()
   }
@@ -21,7 +28,6 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
       sc = null
     }
     // Clean up downloaded file
-    val tmpFile = new File("FileServerSuite.txt")
     if (tmpFile.exists) {
       tmpFile.delete()
     }
@@ -29,15 +35,30 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
 
   test("Distributing files") {
     sc = new SparkContext("local[4]", "test")
-    sc.addFile(System.getProperty("java.io.tmpdir") + "FileServerSuite.txt")
-    val testRdd = sc.parallelize(List(1,2,3,4))
-    val result = testRdd.map { x =>
-      val in = new java.io.BufferedReader(new java.io.FileReader("FileServerSuite.txt"))
+    sc.addFile(tmpFile.toString)
+    val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
+    val result = sc.parallelize(testData).reduceByKey {
+      val in = new java.io.BufferedReader(new java.io.FileReader(tmpFile))
       val fileVal = in.readLine().toInt
       in.close()
-      fileVal
-    }.reduce(_ + _)
-    assert(result == 400)
+      _ * fileVal + _ * fileVal
+    }.collect
+    println(result)
+    assert(result.toSet === Set((1,200), (2,300), (3,500)))
+  }
+
+  test ("Dynamically adding JARS") {
+    sc = new SparkContext("local[4]", "test")
+    val sampleJarFile = getClass().getClassLoader().getResource("uncommons-maths-1.2.2.jar").getFile()
+    sc.addJar(sampleJarFile)
+    val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+    val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+      val fac = Thread.currentThread.getContextClassLoader().loadClass("org.uncommons.maths.Maths").getDeclaredMethod("factorial", classOf[Int])
+      val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+      val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+      a + b
+    }.collect()
+    assert(result.toSet === Set((1,2), (2,7), (3,121)))
   }
 
 }
\ No newline at end of file