author    Matei Zaharia <matei@eecs.berkeley.edu>    2012-09-28 16:14:05 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>    2012-09-28 16:14:05 -0700
commit    0121a26bd150e5f76d950e08cf4d536fad635a40 (patch)
tree      57bf0fb307cb2ad296a31977c8f40b0036523ffc
parent    2a8bfbca00a1701bfe22f5b0967c2d95c088c277 (diff)
Changed the way tasks' dependency files are sent to workers so that
custom serializers or Kryo registrators can be loaded.
-rw-r--r--  bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala | 5
-rw-r--r--  core/src/main/scala/spark/JavaSerializer.scala | 2
-rw-r--r--  core/src/main/scala/spark/KryoSerializer.scala | 5
-rw-r--r--  core/src/main/scala/spark/Serializer.scala | 2
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala | 48
-rw-r--r--  core/src/main/scala/spark/executor/Executor.scala | 56
-rw-r--r--  core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala | 15
-rw-r--r--  core/src/main/scala/spark/scheduler/ShuffleMapTask.scala | 41
-rw-r--r--  core/src/main/scala/spark/scheduler/Task.scala | 107
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala | 6
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala | 3
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala | 68
-rw-r--r--  core/src/main/scala/spark/storage/BlockManagerMaster.scala | 2
-rw-r--r--  core/src/main/scala/spark/util/ByteBufferInputStream.scala | 2
14 files changed, 206 insertions, 156 deletions
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index ed8ace3a57..8ced0f9c73 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -142,7 +142,7 @@ class WPRSerializerInstance extends SerializerInstance {
class WPRSerializationStream(os: OutputStream) extends SerializationStream {
val dos = new DataOutputStream(os)
- def writeObject[T](t: T): Unit = t match {
+ def writeObject[T](t: T): SerializationStream = t match {
case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
case links: Array[String] => {
dos.writeInt(0) // links
@@ -151,17 +151,20 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream {
for (link <- links) {
dos.writeUTF(link)
}
+ this
}
case rank: Double => {
dos.writeInt(1) // rank
dos.writeUTF(id)
dos.writeDouble(rank)
+ this
}
}
case (id: String, rank: Double) => {
dos.writeInt(2) // rank without wrapper
dos.writeUTF(id)
dos.writeDouble(rank)
+ this
}
}
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index d11ba5167d..1511c2620e 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -7,7 +7,7 @@ import spark.util.ByteBufferInputStream
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
- def writeObject[T](t: T) { objOut.writeObject(t) }
+ def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 8aa27a747b..376fcff4c8 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -72,12 +72,13 @@ class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputS
extends SerializationStream {
val channel = Channels.newChannel(out)
- def writeObject[T](t: T) {
+ def writeObject[T](t: T): SerializationStream = {
kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out)
threadBuffer.flip()
channel.write(threadBuffer)
threadBuffer.clear()
+ this
}
def flush() { out.flush() }
@@ -161,6 +162,8 @@ trait KryoRegistrator {
}
class KryoSerializer extends Serializer with Logging {
+ // Make this lazy so that it only gets called once we receive our first task on each executor,
+ // so we can pull out any custom Kryo registrator from the user's JARs.
lazy val kryo = createKryo()
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
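
Making kryo lazy matters because a custom registrator usually lives in an application JAR that only reaches an executor with its first task; by the time createKryo() runs, the ExecutorURLClassLoader already holds that JAR. A minimal sketch of such a registrator, assuming KryoRegistrator exposes a registerClasses(kryo: Kryo) hook and is selected through the spark.kryo.registrator property (neither detail is shown in this hunk):

    // Hypothetical user code, shipped in one of the application JARs.
    package com.example

    import com.esotericsoftware.kryo.Kryo
    import spark.KryoRegistrator

    class MyRegistrator extends KryoRegistrator {
      // Assumed hook: called when KryoSerializer builds its Kryo instance.
      def registerClasses(kryo: Kryo) {
        kryo.register(classOf[PageScore])   // make user classes known to Kryo
      }
    }

    case class PageScore(id: String, rank: Double)

The driver would then point spark.serializer at spark.KryoSerializer and (assuming that property name) spark.kryo.registrator at com.example.MyRegistrator, which is the case the class-loader changes below are meant to support.
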
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index 5f26bd2a7b..9ec07cc173 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -51,7 +51,7 @@ trait SerializerInstance {
* A stream for writing serialized objects.
*/
trait SerializationStream {
- def writeObject[T](t: T): Unit
+ def writeObject[T](t: T): SerializationStream
def flush(): Unit
def close(): Unit
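
Returning the stream from writeObject lets callers chain writes. An illustrative sketch, assuming SerializerInstance.serializeStream, which is defined in this same trait file but not shown in the hunk:

    import java.io.ByteArrayOutputStream
    import spark.SparkEnv

    val out = new ByteArrayOutputStream()
    val stream = SparkEnv.get.serializer.newInstance().serializeStream(out)
    // Each writeObject now hands back the stream, so calls compose fluently.
    stream.writeObject("page-1").writeObject(0.85).writeObject(Seq(1, 2, 3))
    stream.close()
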
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 6ffae8e85f..2c9f46b1a0 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -74,11 +74,18 @@ object SparkEnv {
System.setProperty("spark.master.port", boundPort.toString)
}
- val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer")
- val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
+ val classLoader = Thread.currentThread.getContextClassLoader
+
+ // Create an instance of the class named by the given Java system property, or by
+ // defaultClassName if the property is not set, and return it as a T
+ def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
+ val name = System.getProperty(propertyName, defaultClassName)
+ Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ }
+
+ val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
-
val blockManager = new BlockManager(blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
@@ -87,45 +94,22 @@ object SparkEnv {
val broadcastManager = new BroadcastManager(isMaster)
- val closureSerializerClass =
- System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
- val closureSerializer =
- Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
- val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
- val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+ val closureSerializer = instantiateClass[Serializer](
+ "spark.closure.serializer", "spark.JavaSerializer")
+
+ val cache = instantiateClass[Cache]("spark.cache.class", "spark.BoundedMemoryCache")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
- val shuffleFetcherClass =
- System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
- val shuffleFetcher =
- Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
+ val shuffleFetcher = instantiateClass[ShuffleFetcher](
+ "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
-
- /*
- if (System.getProperty("spark.stream.distributed", "false") == "true") {
- val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
- if (isLocal || !isMaster) {
- (new Thread() {
- override def run() {
- println("Wait started")
- Thread.sleep(60000)
- println("Wait ended")
- val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
- val constructor = receiverClass.getConstructor(blockManagerClass)
- val receiver = constructor.newInstance(blockManager)
- receiver.asInstanceOf[Thread].start()
- }
- }).start()
- }
- }
- */
new SparkEnv(
actorSystem,
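
Loading these classes through the thread's context class loader is what lets the properties name classes that ship in application JARs rather than in Spark itself; for example (class name hypothetical):

    // Driver-side configuration; on the executor side the named class only has to
    // resolve once the ExecutorURLClassLoader has fetched the application JAR.
    System.setProperty("spark.serializer", "com.example.ColumnarSerializer")
    System.setProperty("spark.kryoserializer.buffer.mb", "64")
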
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 9999b6ba80..820428c727 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -20,10 +20,11 @@ class Executor extends Logging {
var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
-
- val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
- val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
-
+
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -67,9 +68,9 @@ class Executor extends Logging {
try {
SparkEnv.set(env)
Accumulators.clear()
- val task = ser.deserialize[Task[Any]](serializedTask, urlClassLoader)
- task.downloadDependencies(fileSet, jarSet)
- updateClassLoader()
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ updateDependencies(taskFiles, taskJars)
+ val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@@ -104,12 +105,11 @@ class Executor extends Logging {
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
-
- var loader = this.getClass().getClassLoader()
+ var loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
- val urls = jarSet.keySet.map { uri =>
+ val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
loader = new URLClassLoader(urls, loader)
@@ -134,22 +134,28 @@ class Executor extends Logging {
return new ExecutorURLClassLoader(Array(), loader)
}
- def updateClassLoader() {
- val currentURLs = urlClassLoader.getURLs()
- val urlSet = jarSet.keySet.map { x => new File(x.split("/").last).toURI.toURL }
- urlSet.filterNot(currentURLs.contains(_)).foreach { url =>
- logInfo("Adding " + url + " to the class loader.")
- urlClassLoader.addURL(url)
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name)
+ Utils.fetchFile(name, new File("."))
+ currentFiles(name) = timestamp
}
-
- }
-
- // The addURL method in URLClassLoader is protected. We subclass it to make it accessible.
- class ExecutorURLClassLoader(urls : Array[URL], parent : ClassLoader)
- extends URLClassLoader(urls, parent) {
- override def addURL(url: URL) {
- super.addURL(url)
+    for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name)
+ Utils.fetchFile(name, new File("."))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(".", localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
-
}
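
The maps that arrive in updateDependencies originate on the driver. A hedged sketch of that side, with the addFile/addJar method names and constructor arguments assumed, since only the addedFiles/addedJars maps appear in this commit:

    import spark.SparkContext

    // Driver program: everything added here is recorded with a timestamp in
    // sc.addedFiles / sc.addedJars and shipped with each task (see TaskSetManager below).
    val sc = new SparkContext("spark://master:7077", "MyApp")
    sc.addJar("http://fileserver:8080/jars/my-app.jar")
    sc.addFile("http://fileserver:8080/data/lookup.txt")
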
diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
new file mode 100644
index 0000000000..f74f036c4c
--- /dev/null
+++ b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala
@@ -0,0 +1,15 @@
+package spark.executor
+
+import java.net.{URLClassLoader, URL}
+
+/**
+ * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
+ */
+private[spark]
+class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends URLClassLoader(urls, parent) {
+
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+}
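
A short usage sketch of the new loader (paths and class name hypothetical): it starts empty, gains URLs as JARs are fetched, and is what Class.forName is pointed at when deserializing tasks.

    import java.io.File
    import spark.executor.ExecutorURLClassLoader

    val loader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
    loader.addURL(new File(".", "my-app.jar").toURI.toURL)   // legal now that addURL is exposed
    val registratorClass = Class.forName("com.example.MyRegistrator", true, loader)
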
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 745aa0c939..d70a061366 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -21,8 +21,6 @@ object ShuffleMapTask {
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new JHashMap[Int, Array[Byte]]
- val fileSetCache = new JHashMap[Int, Array[Byte]]
- val jarSetCache = new JHashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = {
synchronized {
@@ -43,23 +41,6 @@ object ShuffleMapTask {
}
}
- // Since both the JarSet and FileSet have the same format this is used for both.
- def serializeFileSet(
- set : HashMap[String, Long], stageId: Int, cache : JHashMap[Int, Array[Byte]]) : Array[Byte] = {
- val old = cache.get(stageId)
- if (old != null) {
- return old
- } else {
- val out = new ByteArrayOutputStream
- val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
- objOut.writeObject(set.toArray)
- objOut.close()
- val bytes = out.toByteArray
- cache.put(stageId, bytes)
- return bytes
- }
- }
-
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = {
synchronized {
val loader = Thread.currentThread.getContextClassLoader
@@ -83,8 +64,6 @@ object ShuffleMapTask {
def clearCache() {
synchronized {
serializedInfoCache.clear()
- fileSetCache.clear()
- jarSetCache.clear()
}
}
}
@@ -112,15 +91,6 @@ class ShuffleMapTask(
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
-
- val fileSetBytes = ShuffleMapTask.serializeFileSet(
- fileSet, stageId, ShuffleMapTask.fileSetCache)
- out.writeInt(fileSetBytes.length)
- out.write(fileSetBytes)
- val jarSetBytes = ShuffleMapTask.serializeFileSet(jarSet, stageId, ShuffleMapTask.jarSetCache)
- out.writeInt(jarSetBytes.length)
- out.write(jarSetBytes)
-
out.writeInt(partition)
out.writeLong(generation)
out.writeObject(split)
@@ -134,17 +104,6 @@ class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
-
- val fileSetNumBytes = in.readInt()
- val fileSetBytes = new Array[Byte](fileSetNumBytes)
- in.readFully(fileSetBytes)
- fileSet = ShuffleMapTask.deserializeFileSet(fileSetBytes)
-
- val jarSetNumBytes = in.readInt()
- val jarSetBytes = new Array[Byte](jarSetNumBytes)
- in.readFully(jarSetBytes)
- jarSet = ShuffleMapTask.deserializeFileSet(jarSetBytes)
-
partition = in.readInt()
generation = in.readLong()
split = in.readObject().asInstanceOf[Split]
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index 6128e0b273..d69c259362 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,9 +1,12 @@
package spark.scheduler
import scala.collection.mutable.{HashMap}
-import spark.HttpFileServer
-import spark.Utils
-import java.io.File
+import spark.{SerializerInstance, Serializer, Utils}
+import java.io.{DataInputStream, DataOutputStream, File}
+import java.nio.ByteBuffer
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import spark.util.ByteBufferInputStream
+import scala.collection.mutable.HashMap
/**
* A task to execute on a worker node.
@@ -13,30 +16,80 @@ abstract class Task[T](val stageId: Int) extends Serializable {
def preferredLocations: Seq[String] = Nil
var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
-
- // Stores jar and file dependencies for this task.
- var fileSet : HashMap[String, Long] = new HashMap[String, Long]()
- var jarSet : HashMap[String, Long] = new HashMap[String, Long]()
-
- // Downloads all file dependencies from the Master file server
- def downloadDependencies(currentFileSet : HashMap[String, Long],
- currentJarSet : HashMap[String, Long]) {
-
- // Fetch missing file dependencies
- fileSet.filter { case(k,v) =>
- !currentFileSet.contains(k) || currentFileSet(k) < v
- }.foreach { case (k,v) =>
- Utils.fetchFile(k, new File(System.getProperty("user.dir")))
- currentFileSet(k) = v
+}
+
+/**
+ * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We
+ * need to send the list of JARs and files added to the SparkContext with each task to ensure that
+ * worker nodes find out about it, but we can't make it part of the Task because the user's code in
+ * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by
+ * first writing out its dependencies.
+ */
+object Task {
+ /**
+ * Serialize a task and the current app dependencies (files and JARs added to the SparkContext)
+ */
+ def serializeWithDependencies(
+ task: Task[_],
+ currentFiles: HashMap[String, Long],
+ currentJars: HashMap[String, Long],
+ serializer: SerializerInstance)
+ : ByteBuffer = {
+
+ val out = new FastByteArrayOutputStream(4096)
+ val dataOut = new DataOutputStream(out)
+
+ // Write currentFiles
+ dataOut.writeInt(currentFiles.size)
+ for ((name, timestamp) <- currentFiles) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
}
- // Fetch missing jar dependencies
- jarSet.filter { case(k,v) =>
- !currentJarSet.contains(k) || currentJarSet(k) < v
- }.foreach { case (k,v) =>
- Utils.fetchFile(k, new File(System.getProperty("user.dir")))
- currentJarSet(k) = v
+
+ // Write currentJars
+ dataOut.writeInt(currentJars.size)
+ for ((name, timestamp) <- currentJars) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
}
-
+
+ // Write the task itself and finish
+ dataOut.flush()
+ val taskBytes = serializer.serialize(task).array()
+ out.write(taskBytes)
+ out.trim()
+ ByteBuffer.wrap(out.array)
}
-
-}
+
+ /**
+ * Deserialize the list of dependencies in a task serialized with serializeWithDependencies,
+ * and return the task itself as a serialized ByteBuffer. The caller can then update its
+ * ClassLoaders and deserialize the task.
+ *
+ * @return (taskFiles, taskJars, taskBytes)
+ */
+ def deserializeWithDependencies(serializedTask: ByteBuffer)
+ : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+
+ val in = new ByteBufferInputStream(serializedTask)
+ val dataIn = new DataInputStream(in)
+
+ // Read task's files
+ val taskFiles = new HashMap[String, Long]()
+ val numFiles = dataIn.readInt()
+ for (i <- 0 until numFiles) {
+ taskFiles(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Read task's JARs
+ val taskJars = new HashMap[String, Long]()
+ val numJars = dataIn.readInt()
+ for (i <- 0 until numJars) {
+ taskJars(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Create a sub-buffer for the rest of the data, which is the serialized Task object
+ val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
+ (taskFiles, taskJars, subBuffer)
+ }
+}
\ No newline at end of file
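
The resulting layout is: numFiles, then (name, timestamp) pairs, then numJars and its pairs, then the serialized Task bytes. An illustrative round-trip of just the dependency header, mirroring the code above (helper names hypothetical):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import scala.collection.mutable.HashMap

    def writeDeps(out: DataOutputStream, deps: HashMap[String, Long]) {
      out.writeInt(deps.size)
      for ((name, timestamp) <- deps) {
        out.writeUTF(name)
        out.writeLong(timestamp)
      }
    }

    def readDeps(in: DataInputStream): HashMap[String, Long] = {
      val deps = new HashMap[String, Long]()
      for (i <- 0 until in.readInt()) {
        deps(in.readUTF()) = in.readLong()
      }
      deps
    }

    // Round trip: files header followed by JARs header, exactly as the task path does.
    val buf = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(buf)
    writeDeps(dataOut, HashMap("lookup.txt" -> 1L))
    writeDeps(dataOut, HashMap("my-app.jar" -> 2L))
    dataOut.flush()
    val dataIn = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    val (files, jars) = (readDeps(dataIn), readDeps(dataIn))
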
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 952c9766bf..16fe5761c8 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -16,7 +16,7 @@ import java.util.concurrent.atomic.AtomicLong
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
* start(), then submit task sets through the runTasks method.
*/
-class ClusterScheduler(sc: SparkContext)
+class ClusterScheduler(val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -87,10 +87,6 @@ class ClusterScheduler(sc: SparkContext)
def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
- tasks.foreach { task =>
- task.fileSet ++= sc.addedFiles
- task.jarSet ++= sc.addedJars
- }
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index e25a11e7c5..aa37462fb0 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -214,7 +214,8 @@ class TaskSetManager(
}
// Serialize and return the task
val startTime = System.currentTimeMillis
- val serializedTask = ser.serialize(task)
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 65078b026e..53fc659345 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.HashMap
import spark._
+import executor.ExecutorURLClassLoader
import spark.scheduler._
/**
@@ -14,13 +15,21 @@ import spark.scheduler._
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
-class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends TaskScheduler with Logging {
+class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
+ extends TaskScheduler
+ with Logging {
+
var attemptId = new AtomicInteger(0)
var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
- val fileSet: HashMap[String, Long] = new HashMap[String, Long]()
- val jarSet: HashMap[String, Long] = new HashMap[String, Long]()
+
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+
+ val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
// TODO: Need to take into account stage priority in scheduling
@@ -35,8 +44,6 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
- task.fileSet ++= sc.addedFiles
- task.jarSet ++= sc.addedJars
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
@@ -49,19 +56,23 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
logInfo("Running task " + idInJob)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
- task.downloadDependencies(fileSet, jarSet)
- // Create a new classLaoder for the downloaded JARs
- Thread.currentThread.setContextClassLoader(createClassLoader())
try {
+ Accumulators.clear()
+ Thread.currentThread().setContextClassLoader(classLoader)
+
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- Accumulators.clear
val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = ser.serialize(task)
+ val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+ updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserializedTask = ser.deserialize[Task[_]](
- bytes, Thread.currentThread.getContextClassLoader)
+ taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // Run it
val result: Any = deserializedTask.run(attemptId)
+
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
@@ -90,20 +101,35 @@ class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) extends T
submitTask(task, i)
}
}
-
+
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name)
+ Utils.fetchFile(name, new File("."))
+ currentFiles(name) = timestamp
+ }
+    for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name)
+ Utils.fetchFile(name, new File("."))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(".", localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
+ }
+ }
override def stop() {
threadPool.shutdownNow()
}
- private def createClassLoader() : ClassLoader = {
- val currentLoader = Thread.currentThread.getContextClassLoader()
- val urls = jarSet.keySet.map { uri =>
- new File(uri.split("/").last).toURI.toURL
- }.toArray
- logInfo("Creating ClassLoader with jars: " + urls.mkString)
- return new URLClassLoader(urls, currentLoader)
- }
-
override def defaultParallelism() = threads
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 2f14db4e28..8e4f9f7c15 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -395,10 +395,12 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B
}
def mustRegisterBlockManager(msg: RegisterBlockManager) {
+ logInfo("Trying to register BlockManager")
while (! syncRegisterBlockManager(msg)) {
logWarning("Failed to register " + msg)
Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
}
+ logInfo("Done registering BlockManager")
}
def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
index 0ce255105a..c92b60a40c 100644
--- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala
@@ -31,4 +31,6 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
buffer.position(buffer.position + amountToSkip)
return amountToSkip
}
+
+ def position: Int = buffer.position
}