diff options
Diffstat (limited to 'core/src/main')
95 files changed, 3376 insertions, 1893 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index bacd0ace37..57c6df35be 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -25,8 +25,7 @@ class Accumulable[R, T] ( extends Serializable { val id = Accumulators.newId - @transient - private var value_ = initialValue // Current value on master + @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers var deserialized = false @@ -39,19 +38,36 @@ class Accumulable[R, T] ( def += (term: T) { value_ = param.addAccumulator(value_, term) } /** + * Add more data to this accumulator / accumulable + * @param term the data to add + */ + def add(term: T) { value_ = param.addAccumulator(value_, term) } + + /** * Merge two accumulable objects together - * + * * Normally, a user will not want to use this version, but will instead call `+=`. - * @param term the other Accumulable that will get merged with this + * @param term the other `R` that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} /** + * Merge two accumulable objects together + * + * Normally, a user will not want to use this version, but will instead call `add`. + * @param term the other `R` that will get merged with this + */ + def merge(term: R) { value_ = param.addInPlace(value_, term)} + + /** * Access the accumulator's current value; only allowed on master. */ - def value = { - if (!deserialized) value_ - else throw new UnsupportedOperationException("Can't read accumulator value in task") + def value: R = { + if (!deserialized) { + value_ + } else { + throw new UnsupportedOperationException("Can't read accumulator value in task") + } } /** @@ -68,10 +84,17 @@ class Accumulable[R, T] ( /** * Set the accumulator's value; only allowed on master. */ - def value_= (r: R) { - if (!deserialized) value_ = r + def value_= (newValue: R) { + if (!deserialized) value_ = newValue else throw new UnsupportedOperationException("Can't assign accumulator value in task") } + + /** + * Set the accumulator's value; only allowed on master + */ + def setValue(newValue: R) { + this.value = newValue + } // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala deleted file mode 100644 index e8392a194f..0000000000 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ /dev/null @@ -1,118 +0,0 @@ -package spark - -import java.util.LinkedHashMap - -/** - * An implementation of Cache that estimates the sizes of its entries and attempts to limit its - * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using - * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if - * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well - * when most of the space is used by arrays of primitives or of simple classes. - */ -private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { - logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) - - def this() { - this(BoundedMemoryCache.getMaxBytes) - } - - private var currentBytes = 0L - private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true) - - override def get(datasetId: Any, partition: Int): Any = { - synchronized { - val entry = map.get((datasetId, partition)) - if (entry != null) { - entry.value - } else { - null - } - } - } - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - val key = (datasetId, partition) - logInfo("Asked to add key " + key) - val size = estimateValueSize(key, value) - synchronized { - if (size > getCapacity) { - return CachePutFailure() - } else if (ensureFreeSpace(datasetId, size)) { - logInfo("Adding key " + key) - map.put(key, new Entry(value, size)) - currentBytes += size - logInfo("Number of entries is now " + map.size) - return CachePutSuccess(size) - } else { - logInfo("Didn't add key " + key + " because we would have evicted part of same dataset") - return CachePutFailure() - } - } - } - - override def getCapacity: Long = maxBytes - - /** - * Estimate sizeOf 'value' - */ - private def estimateValueSize(key: (Any, Int), value: Any) = { - val startTime = System.currentTimeMillis - val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef]) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Estimated size for key %s is %d".format(key, size)) - logInfo("Size estimation for key %s took %d ms".format(key, timeTaken)) - size - } - - /** - * Remove least recently used entries from the map until at least space bytes are free, in order - * to make space for a partition from the given dataset ID. If this cannot be done without - * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes - * that a lock is held on the BoundedMemoryCache. - */ - private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = { - logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format( - datasetId, space, currentBytes, maxBytes)) - val iter = map.entrySet.iterator // Will give entries in LRU order - while (maxBytes - currentBytes < space && iter.hasNext) { - val mapEntry = iter.next() - val (entryDatasetId, entryPartition) = mapEntry.getKey - if (entryDatasetId == datasetId) { - // Cannot make space without removing part of the same dataset, or a more recently used one - return false - } - reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue) - currentBytes -= mapEntry.getValue.size - iter.remove() - } - return true - } - - protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - // TODO: remove BoundedMemoryCache - - val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)] - innerDatasetId match { - case rddId: Int => - SparkEnv.get.cacheTracker.dropEntry(rddId, partition) - case broadcastUUID: java.util.UUID => - // TODO: Maybe something should be done if the broadcasted variable falls out of cache - case _ => - } - } -} - -// An entry in our map; stores a cached object and its size in bytes -private[spark] case class Entry(value: Any, size: Long) - -private[spark] object BoundedMemoryCache { - /** - * Get maximum cache capacity from system configuration - */ - def getMaxBytes: Long = { - val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble - (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong - } -} - diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala new file mode 100644 index 0000000000..a0b53fd9d6 --- /dev/null +++ b/core/src/main/scala/spark/CacheManager.scala @@ -0,0 +1,65 @@ +package spark + +import scala.collection.mutable.{ArrayBuffer, HashSet} +import spark.storage.{BlockManager, StorageLevel} + + +/** Spark class responsible for passing RDDs split contents to the BlockManager and making + sure a node doesn't load two copies of an RDD at once. + */ +private[spark] class CacheManager(blockManager: BlockManager) extends Logging { + private val loading = new HashSet[String] + + /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */ + def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + : Iterator[T] = { + val key = "rdd_%d_%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) + loading.synchronized { + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } + } + try { + // If we got here, we have to load the split + val elements = new ArrayBuffer[Any] + logInfo("Computing partition " + split) + elements ++= rdd.compute(split, context) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) + return elements.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } + } + } + } +} diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala deleted file mode 100644 index 3d79078733..0000000000 --- a/core/src/main/scala/spark/CacheTracker.scala +++ /dev/null @@ -1,238 +0,0 @@ -package spark - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet - -import akka.actor._ -import akka.dispatch._ -import akka.pattern.ask -import akka.remote._ -import akka.util.Duration -import akka.util.Timeout -import akka.util.duration._ - -import spark.storage.BlockManager -import spark.storage.StorageLevel - -private[spark] sealed trait CacheTrackerMessage - -private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) - extends CacheTrackerMessage -private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage -private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -private[spark] case object GetCacheStatus extends CacheTrackerMessage -private[spark] case object GetCacheLocations extends CacheTrackerMessage -private[spark] case object StopCacheTracker extends CacheTrackerMessage - -private[spark] class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples - private val locs = new HashMap[Int, Array[List[String]]] - - /** - * A map from the slave's host name to its cache size. - */ - private val slaveCapacity = new HashMap[String, Long] - private val slaveUsage = new HashMap[String, Long] - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) - private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) - private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - sender ! true - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - sender ! true - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - locs(rddId)(partition) = host :: locs(rddId)(partition) - sender ! true - - case DroppedFromCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - sender ! true - - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - sender ! true - - case GetCacheLocations => - logInfo("Asked for current cache locations") - sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} - - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - sender ! status - - case StopCacheTracker => - logInfo("Stopping CacheTrackerActor") - sender ! true - context.stop(self) - } -} - -private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) - extends Logging { - - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val actorName: String = "CacheTracker" - - val timeout = 10.seconds - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) - logInfo("Registered CacheTrackerActor actor") - actor - } else { - val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName) - actorSystem.actorFor(url) - } - - val registeredRddIds = new HashSet[Int] - - // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] - - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askTracker(message: Any): Any = { - try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with CacheTracker", e) - } - } - - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askTracker(message) != true) { - throw new SparkException("Error reply received from CacheTracker") - } - } - - // Registers an RDD (on master only) - def registerRDD(rddId: Int, numPartitions: Int) { - registeredRddIds.synchronized { - if (!registeredRddIds.contains(rddId)) { - logInfo("Registering RDD ID " + rddId + " with cache") - registeredRddIds += rddId - communicate(RegisterRDD(rddId, numPartitions)) - } - } - } - - // For BlockManager.scala only - def cacheLost(host: String) { - communicate(MemoryCacheLost(host)) - logInfo("CacheTracker successfully removed entries on " + host) - } - - // Get the usage status of slave caches. Each tuple in the returned sequence - // is in the form of (host name, capacity, usage). - def getCacheStatus(): Seq[(String, Long, Long)] = { - askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] - } - - // For BlockManager.scala only - def notifyFromBlockManager(t: AddedToCache) { - communicate(t) - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] - } - - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) - : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) - loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } - } - // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads - logInfo("Computing partition " + split) - val elements = new ArrayBuffer[Any] - elements ++= rdd.compute(split, context) - try { - // Try to put this block in the blockManager - blockManager.put(key, elements, storageLevel, true) - //future.apply() // Wait for the reply from the cache tracker - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } - } - return elements.iterator.asInstanceOf[Iterator[T]] - } - } - - // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) - } - - def stop() { - communicate(StopCacheTracker) - registeredRddIds.clear() - trackerActor = null - } -} diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala deleted file mode 100644 index 56e59adeb7..0000000000 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark - -import java.util.concurrent.ThreadFactory - -/** - * A ThreadFactory that creates daemon threads - */ -private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = new DaemonThread(r) -} - -private class DaemonThread(r: Runnable = null) extends Thread { - override def run() { - if (r != null) { - r.run() - } - } -}
\ No newline at end of file diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala index 659d17718f..00901d95e2 100644 --- a/core/src/main/scala/spark/HttpFileServer.scala +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -1,9 +1,7 @@ package spark -import java.io.{File, PrintWriter} -import java.net.URL -import scala.collection.mutable.HashMap -import org.apache.hadoop.fs.FileUtil +import java.io.{File} +import com.google.common.io.Files private[spark] class HttpFileServer extends Logging { @@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging { } def addFileToDir(file: File, dir: File) : String = { - Utils.copyFile(file, new File(dir, file.getName)) + Files.copy(file, new File(dir, file.getName)) return dir + "/" + file.getName } diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 0196595ba1..4e0507c080 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -4,6 +4,7 @@ import java.io.File import java.net.InetAddress import org.eclipse.jetty.server.Server +import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.handler.DefaultHandler import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.server.handler.ResourceHandler @@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { if (server != null) { throw new ServerStateException("Server is already started") } else { - server = new Server(0) + server = new Server() + val connector = new SocketConnector + connector.setMaxIdleTime(60*1000) + connector.setSoLingerTime(-1) + connector.setPort(0) + server.addConnector(connector) + val threadPool = new QueuedThreadPool threadPool.setDaemon(true) server.setThreadPool(threadPool) diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 44b630e478..0bd73e936b 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -9,153 +9,80 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} -import com.esotericsoftware.kryo.serialize.ClassSerializer -import com.esotericsoftware.kryo.serialize.SerializableSerializer +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import de.javakaffee.kryoserializers.KryoReflectionFactorySupport import serializer.{SerializerInstance, DeserializationStream, SerializationStream} import spark.broadcast._ import spark.storage._ -/** - * Zig-zag encoder used to write object sizes to serialization streams. - * Based on Kryo's integer encoder. - */ -private[spark] object ZigZag { - def writeInt(n: Int, out: OutputStream) { - var value = n - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - if ((value & ~0x7F) == 0) { - out.write(value) - return - } - out.write(((value & 0x7F) | 0x80)) - value >>>= 7 - out.write(value) - } +private[spark] +class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream { - def readInt(in: InputStream): Int = { - var offset = 0 - var result = 0 - while (offset < 32) { - val b = in.read() - if (b == -1) { - throw new EOFException("End of stream") - } - result |= ((b & 0x7F) << offset) - if ((b & 0x80) == 0) { - return result - } - offset += 7 - } - throw new SparkException("Malformed zigzag-encoded integer") - } -} - -private[spark] -class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) -extends SerializationStream { - val channel = Channels.newChannel(out) + val output = new KryoOutput(outStream) def writeObject[T](t: T): SerializationStream = { - kryo.writeClassAndObject(threadBuffer, t) - ZigZag.writeInt(threadBuffer.position(), out) - threadBuffer.flip() - channel.write(threadBuffer) - threadBuffer.clear() + kryo.writeClassAndObject(output, t) this } - def flush() { out.flush() } - def close() { out.close() } + def flush() { output.flush() } + def close() { output.close() } } -private[spark] -class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) -extends DeserializationStream { +private[spark] +class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream { + + val input = new KryoInput(inStream) + def readObject[T](): T = { - val len = ZigZag.readInt(in) - objectBuffer.readClassAndObject(in, len).asInstanceOf[T] + try { + kryo.readClassAndObject(input).asInstanceOf[T] + } catch { + // DeserializationStream uses the EOF exception to indicate stopping condition. + case e: com.esotericsoftware.kryo.KryoException => throw new java.io.EOFException + } } - def close() { in.close() } + def close() { + // Kryo's Input automatically closes the input stream it is using. + input.close() + } } private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.kryo - val threadBuffer = ks.threadBuffer.get() - val objectBuffer = ks.objectBuffer.get() + + val kryo = ks.kryo.get() + val output = ks.output.get() + val input = ks.input.get() def serialize[T](t: T): ByteBuffer = { - // Write it to our thread-local scratch buffer first to figure out the size, then return a new - // ByteBuffer of the appropriate size - threadBuffer.clear() - kryo.writeClassAndObject(threadBuffer, t) - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf + output.clear() + kryo.writeClassAndObject(output, t) + ByteBuffer.wrap(output.toBytes) } def deserialize[T](bytes: ByteBuffer): T = { - kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + kryo.readClassAndObject(input).asInstanceOf[T] } def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { val oldClassLoader = kryo.getClassLoader kryo.setClassLoader(loader) - val obj = kryo.readClassAndObject(bytes).asInstanceOf[T] + input.setBuffer(bytes.array) + val obj = kryo.readClassAndObject(input).asInstanceOf[T] kryo.setClassLoader(oldClassLoader) obj } def serializeStream(s: OutputStream): SerializationStream = { - threadBuffer.clear() - new KryoSerializationStream(kryo, threadBuffer, s) + new KryoSerializationStream(kryo, s) } def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(objectBuffer, s) - } - - override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - threadBuffer.clear() - while (iterator.hasNext) { - val element = iterator.next() - // TODO: Do we also want to write the object's size? Doesn't seem necessary. - kryo.writeClassAndObject(threadBuffer, element) - } - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf - } - - override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - buffer.rewind() - new Iterator[Any] { - override def hasNext: Boolean = buffer.remaining > 0 - override def next(): Any = kryo.readClassAndObject(buffer) - } + new KryoDeserializationStream(kryo, s) } } @@ -171,18 +98,19 @@ trait KryoRegistrator { * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. */ class KryoSerializer extends spark.serializer.Serializer with Logging { - // Make this lazy so that it only gets called once we receive our first task on each executor, - // so we can pull out any custom Kryo registrator from the user's JARs. - lazy val kryo = createKryo() - val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 + val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - val objectBuffer = new ThreadLocal[ObjectBuffer] { - override def initialValue = new ObjectBuffer(kryo, bufferSize) + val kryo = new ThreadLocal[Kryo] { + override def initialValue = createKryo() } - val threadBuffer = new ThreadLocal[ByteBuffer] { - override def initialValue = ByteBuffer.allocate(bufferSize) + val output = new ThreadLocal[KryoOutput] { + override def initialValue = new KryoOutput(bufferSize) + } + + val input = new ThreadLocal[KryoInput] { + override def initialValue = new KryoInput(bufferSize) } def createKryo(): Kryo = { @@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo.register(obj.getClass) } - // Register the following classes for passing closures. - kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) - kryo.setRegistrationOptional(true) - // Allow sending SerializableWritable - kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) + kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. - class SingletonSerializer(obj: AnyRef) extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {} - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T] + class SingletonSerializer[T](obj: T) extends KSerializer[T] { + override def write(kryo: Kryo, output: KryoOutput, obj: T) {} + override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj } - kryo.register(None.getClass, new SingletonSerializer(None)) - kryo.register(Nil.getClass, new SingletonSerializer(Nil)) + kryo.register(None.getClass, new SingletonSerializer[AnyRef](None)) + kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil)) // Register maps with a special serializer since they have complex internal structure class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any]) - extends KSerializer { - override def writeObjectData(buf: ByteBuffer, obj: AnyRef) { + extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] { + override def write( + kryo: Kryo, + output: KryoOutput, + obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) { val map = obj.asInstanceOf[scala.collection.Map[Any, Any]] - kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer]) + kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer]) for ((k, v) <- map) { - kryo.writeClassAndObject(buf, k) - kryo.writeClassAndObject(buf, v) + kryo.writeClassAndObject(output, k) + kryo.writeClassAndObject(output, v) } } - override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = { - val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue + override def read ( + kryo: Kryo, + input: KryoInput, + cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]]) + : Array[(Any, Any)] => scala.collection.Map[Any, Any] = { + val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue val elems = new Array[(Any, Any)](size) for (i <- 0 until size) - elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf)) - buildMap(elems).asInstanceOf[T] + elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input)) + buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]] } } kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _)) @@ -275,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging { kryo } - def newInstance(): SerializerInstance = new KryoSerializerInstance(this) + def newInstance(): SerializerInstance = { + this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader) + new KryoSerializerInstance(this) + } } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 90bae26202..7c1c1bb144 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine - @transient - private var log_ : Logger = null + @transient private var log_ : Logger = null // Method to get or create the logger for this object protected def log: Logger = { diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e..ac02f3363a 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -17,6 +17,7 @@ import akka.util.duration._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.{MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage @@ -44,7 +45,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea val timeout = 10.seconds - var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. @@ -53,7 +54,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Cache a serialized version of the output statuses for each shuffle to send them out faster var cacheGeneration = generation - val cachedSerializedStatuses = new HashMap[Int, Array[Byte]] + val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] var trackerActor: ActorRef = if (isMaster) { val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) @@ -64,6 +65,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea actorSystem.actorFor(url) } + val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) + // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. def askTracker(message: Any): Any = { @@ -84,14 +87,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.get(shuffleId) != null) { + if (mapStatuses.get(shuffleId) != None) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) array.synchronized { array(mapId) = status } @@ -108,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = mapStatuses.get(shuffleId) + var array = mapStatuses(shuffleId) if (array != null) { array.synchronized { if (array(mapId) != null && array(mapId).address == bmAddress) { @@ -126,7 +129,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { - val statuses = mapStatuses.get(shuffleId) + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { @@ -139,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId)) } else { fetching += shuffleId } @@ -156,27 +158,27 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } + def cleanup(cleanupTime: Long) { + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) + } + def stop() { communicate(StopMapOutputTracker) mapStatuses.clear() + metadataCleaner.cancel() trackerActor = null } @@ -202,7 +204,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] + mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] generation = newGen } } @@ -220,7 +222,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case Some(bytes) => return bytes case None => - statuses = mapStatuses.get(shuffleId) + statuses = mapStatuses(shuffleId) generationGotten = generation } } @@ -258,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. * We do this by encoding the log base 1.1 of the size as an integer, which can support diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 08ae06e865..53b051f1c5 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true): RDD[(K, C)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (mapSideCombine) { @@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + + if (getKeyClass().isArray) { + throw new SparkException("reduceByKeyLocally() does not support array keys") + } + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { val map = new JHashMap[K, V] for ((k, v) <- iter) { @@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * be set to true. */ def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = { + if (getKeyClass().isArray) { + if (mapSideCombine) { + throw new SparkException("Cannot use map-side combining with array keys.") + } + if (partitioner.isInstanceOf[HashPartitioner]) { + throw new SparkException("Default partitioner cannot partition array keys.") + } + } if (mapSideCombine) { def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v @@ -178,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** - * Merge the values for each key using an associative reduce function. This will also perform - * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { @@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), partitioner) @@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { + throw new SparkException("Default partitioner cannot partition array keys.") + } val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other1.asInstanceOf[RDD[(_, _)]], @@ -438,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( val res = self.context.runJob(self, process _, Array(index), false) res(0) case None => - throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner") + self.filter(_._1 == key).map(_._2).collect } } @@ -466,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( path: String, keyClass: Class[_], valueClass: Class[_], - outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) { - saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) - } - - /** - * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` - * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. - */ - def saveAsNewAPIHadoopFile( - path: String, - keyClass: Class[_], - valueClass: Class[_], outputFormatClass: Class[_ <: NewOutputFormat[_, _]], - conf: Configuration) { + conf: Configuration = self.context.hadoopConfiguration) { val job = new NewAPIHadoopJob(conf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) @@ -530,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[_ <: OutputFormat[_, _]], - conf: JobConf = new JobConf) { + conf: JobConf = new JobConf(self.context.hadoopConfiguration)) { conf.setOutputKeyClass(keyClass) conf.setOutputValueClass(valueClass) // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug @@ -588,6 +603,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.cleanup() } + /** + * Return an RDD with the keys of each tuple. + */ + def keys: RDD[K] = self.map(_._1) + + /** + * Return an RDD with the values of each tuple. + */ + def values: RDD[V] = self.map(_._2) + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure @@ -624,24 +649,23 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( } private[spark] -class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner - override def compute(split: Split, taskContext: TaskContext) = - prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))} +class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) + extends RDD[(K, U)](prev) { + + override def getSplits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split, context: TaskContext) = + firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) } } private[spark] class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) - extends RDD[(K, U)](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override val partitioner = prev.partitioner + extends RDD[(K, U)](prev) { - override def compute(split: Split, taskContext: TaskContext) = { - prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) } + override def getSplits = firstParent[(K, V)].splits + override val partitioner = firstParent[(K, V)].partitioner + override def compute(split: Split, context: TaskContext) = { + firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index a27f766e31..10adcd53ec 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -2,6 +2,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import scala.collection.Map private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, @@ -22,28 +23,33 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( } private[spark] class ParallelCollection[T: ClassManifest]( - sc: SparkContext, + @transient sc: SparkContext, @transient data: Seq[T], - numSlices: Int) - extends RDD[T](sc) { + numSlices: Int, + locationPrefs: Map[Int,Seq[String]]) + extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split // instead. + // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - @transient - val splits_ = { + @transient var splits_ : Array[Split] = { val slices = ParallelCollection.slice(data, numSlices).toArray slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_ - override def compute(s: Split, taskContext: TaskContext) = + override def compute(s: Split, context: TaskContext) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - override def preferredLocations(s: Split): Seq[String] = Nil + override def getPreferredLocations(s: Split): Seq[String] = { + locationPrefs.getOrElse(s.index, Nil) + } - override val dependencies: List[Dependency[_]] = Nil + override def clearDependencies() { + splits_ = null + } } private object ParallelCollection { diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index b71021a082..9d5b966e1e 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable { /** * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + * + * Java arrays have hashCodes that are based on the arrays' identities rather than their contents, + * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will + * produce an unexpected or incorrect result. */ class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index bb4c13c494..c79f34342f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -1,10 +1,8 @@ package spark -import java.io.EOFException -import java.io.ObjectInputStream +import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream} import java.net.URL -import java.util.Random -import java.util.Date +import java.util.{Date, Random} import java.util.{HashMap => JHashMap} import java.util.concurrent.atomic.AtomicLong @@ -13,6 +11,7 @@ import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text @@ -73,41 +72,42 @@ import SparkContext._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { +abstract class RDD[T: ClassManifest]( + @transient var sc: SparkContext, + var dependencies_ : List[Dependency[_]] + ) extends Serializable with Logging { - // Methods that must be implemented by subclasses: - /** Set of partitions in this RDD. */ - def splits: Array[Split] + def this(@transient oneParent: RDD[_]) = + this(oneParent.context , List(new OneToOneDependency(oneParent))) + + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= /** Function for computing a given partition. */ def compute(split: Split, context: TaskContext): Iterator[T] - /** How this RDD depends on any parent RDDs. */ - @transient val dependencies: List[Dependency[_]] + /** Set of partitions in this RDD. */ + protected def getSplits(): Array[Split] - // Methods available on all RDDs: + /** How this RDD depends on any parent RDDs. */ + protected def getDependencies(): List[Dependency[_]] = dependencies_ - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite + /** Optionally overridden by subclasses to specify placement preferences. */ + protected def getPreferredLocations(split: Split): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None - /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE - /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. @@ -131,22 +131,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { - if (!level.useDisk && level.replication < 2) { - throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") + /** + * Get the preferred location of a split, taking into account whether the + * RDD is checkpointed or not. + */ + final def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointData.get.getPreferredLocations(split) + } else { + getPreferredLocations(split) } + } - // This is a hack. Ideally this should re-use the code used by the CacheTracker - // to generate the key. - def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) - - persist(level) - sc.runJob(this, (iter: Iterator[T]) => {} ) - - val p = this.partitioner + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def splits: Array[Split] = { + if (isCheckpointed) { + checkpointData.get.getSplits + } else { + getSplits + } + } - new BlockRDD[T](sc, splits.map(getSplitKey).toArray) { - override val partitioner = p + /** + * Get the list of dependencies of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: List[Dependency[_]] = { + if (isCheckpointed) { + dependencies_ + } else { + getDependencies } } @@ -156,8 +173,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * subclasses of RDD. */ final def iterator(split: Split, context: TaskContext): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel) + if (isCheckpointed) { + checkpointData.get.iterator(split, context) + } else if (storageLevel != StorageLevel.NONE) { + SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { compute(split, context) } @@ -185,9 +204,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int = splits.size): RDD[T] = + def distinct(numSplits: Int): RDD[T] = map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1) + def distinct(): RDD[T] = distinct(splits.size) + /** * Return a sampled subset of this RDD. */ @@ -328,6 +349,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial def toArray(): Array[T] = collect() /** + * Return an RDD that contains all matching values by applying `f`. + */ + def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + filter(f.isDefinedAt).map(f) + } + + /** * Reduces the elements of this RDD using the specified associative binary operator. */ def reduce(f: (T, T) => T): T = { @@ -415,6 +443,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValue() does not support arrays") + } // TODO: This should perhaps be distributed by default. def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { val map = new OLMap[T] @@ -443,6 +474,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { + if (elementClassManifest.erasure.isArray) { + throw new SparkException("countByValueApprox() does not support arrays") + } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => val map = new OLMap[T] while (iter.hasNext) { @@ -502,8 +536,95 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial .saveAsSequenceFile(path) } + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: T => K): RDD[(K, T)] = { + map(x => (f(x), x)) + } + /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. + */ + def checkpoint() { + if (context.checkpointDir.isEmpty) { + throw new Exception("Checkpoint directory has not been set in the SparkContext") + } else if (checkpointData.isEmpty) { + checkpointData = Some(new RDDCheckpointData(this)) + checkpointData.get.markForCheckpoint() + } + } + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = { + if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false + } + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Option[String] = { + if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None + } + + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[spark.SparkContext]] that this RDD was created on. */ + def context = sc + + /** + * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler + * after a job using this RDD has completed (therefore the RDD has been materialized and + * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. + */ + protected[spark] def doCheckpoint() { + if (checkpointData.isDefined) checkpointData.get.doCheckpoint() + dependencies.foreach(_.rdd.doCheckpoint()) + } + + /** + * Changes the dependencies of this RDD from its original parents to the new RDD + * (`newRDD`) created from the checkpoint file. + */ + protected[spark] def changeDependencies(newRDD: RDD[_]) { + clearDependencies() + dependencies_ = List(new OneToOneDependency(newRDD)) + } + + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own cleaning + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + */ + protected[spark] def clearDependencies() { + dependencies_ = null + } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala new file mode 100644 index 0000000000..18df530b7d --- /dev/null +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -0,0 +1,105 @@ +package spark + +import org.apache.hadoop.fs.Path +import rdd.{CheckpointRDD, CoalescedRDD} +import scheduler.{ResultTask, ShuffleMapTask} + +/** + * Enumeration to manage state transitions of an RDD through checkpointing + * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ] + */ +private[spark] object CheckpointState extends Enumeration { + type CheckpointState = Value + val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value +} + +/** + * This class contains all the information related to RDD checkpointing. Each instance of this class + * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, + * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations + * of the checkpointed RDD. + */ +private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +extends Logging with Serializable { + + import CheckpointState._ + + // The checkpoint state of the associated RDD. + var cpState = Initialized + + // The file to which the associated RDD has been checkpointed to + @transient var cpFile: Option[String] = None + + // The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD. + @transient var cpRDD: Option[RDD[T]] = None + + // Mark the RDD for checkpointing + def markForCheckpoint() { + RDDCheckpointData.synchronized { + if (cpState == Initialized) cpState = MarkedForCheckpoint + } + } + + // Is the RDD already checkpointed + def isCheckpointed(): Boolean = { + RDDCheckpointData.synchronized { cpState == Checkpointed } + } + + // Get the file to which this RDD was checkpointed to as an Option + def getCheckpointFile(): Option[String] = { + RDDCheckpointData.synchronized { cpFile } + } + + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. + def doCheckpoint() { + // If it is marked for checkpointing AND checkpointing is not already in progress, + // then set it to be in progress, else return + RDDCheckpointData.synchronized { + if (cpState == MarkedForCheckpoint) { + cpState = CheckpointingInProgress + } else { + return + } + } + + // Save to file, and reload it as an RDD + val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString + rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) + val newRDD = new CheckpointRDD[T](rdd.context, path) + + // Change the dependencies and splits of the RDD + RDDCheckpointData.synchronized { + cpFile = Some(path) + cpRDD = Some(newRDD) + rdd.changeDependencies(newRDD) + cpState = Checkpointed + RDDCheckpointData.clearTaskCaches() + logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) + } + } + + // Get preferred location of a split after checkpointing + def getPreferredLocations(split: Split) = { + RDDCheckpointData.synchronized { + cpRDD.get.preferredLocations(split) + } + } + + def getSplits: Array[Split] = { + RDDCheckpointData.synchronized { + cpRDD.get.splits + } + } + + // Get iterator. This is called at the worker nodes. + def iterator(split: Split, context: TaskContext): Iterator[T] = { + rdd.firstParent[T].iterator(split, context) + } +} + +private[spark] object RDDCheckpointData { + def clearTaskCaches() { + ShuffleMapTask.clearCache() + ResultTask.clearCache() + } +} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index a34aee69c1..6b4a11d6d3 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { classManifest[T].erasure } else { - implicitly[T => Writable].getClass.getMethods()(0).getReturnType + // We get the type of the Writable class by looking at the apply method which converts + // from T to Writable. Since we have two apply methods we filter out the one which + // is of the form "java.lang.Object apply(java.lang.Object)" + implicitly[T => Writable].getClass.getDeclaredMethods().filter( + m => m.getReturnType().toString != "java.lang.Object" && + m.getName() == "apply")(0).getReturnType + } // TODO: use something like WritableConverter to avoid reflection } diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index 7c3e8640e9..d4e1157250 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -9,7 +9,6 @@ import java.util.Random import javax.management.MBeanServer import java.lang.management.ManagementFactory -import com.sun.management.HotSpotDiagnosticMXBean import scala.collection.mutable.ArrayBuffer @@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging { if (System.getProperty("spark.test.useCompressedOops") != null) { return System.getProperty("spark.test.useCompressedOops").toBoolean } + try { val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" val server = ManagementFactory.getPlatformMBeanServer() + + // NOTE: This should throw an exception in non-Sun JVMs + val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean") + val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", + Class.forName("java.lang.String")) + val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) - return bean.getVMOption("UseCompressedOops").getValue.toBoolean + hotSpotMBeanName, hotSpotMBeanClass) + // TODO: We could use reflection on the VMOption returned ? + return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { case e: Exception => { // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0afab522af..66bdbe7cda 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,10 +3,12 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.{URI, URLClassLoader} +import java.lang.ref.WeakReference import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.JavaConversions._ import akka.actor.Actor import akka.actor.Actor._ @@ -36,12 +38,8 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.rdd.HadoopRDD -import spark.rdd.NewHadoopRDD -import spark.rdd.UnionRDD -import spark.scheduler.ShuffleMapTask -import spark.scheduler.DAGScheduler -import spark.scheduler.TaskScheduler +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} +import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} @@ -58,29 +56,13 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend * @param environment Environment variables to set on worker nodes. */ class SparkContext( - master: String, - jobName: String, - val sparkHome: String, - jars: Seq[String], - environment: Map[String, String]) + val master: String, + val jobName: String, + val sparkHome: String = null, + val jars: Seq[String] = Nil, + environment: Map[String, String] = Map()) extends Logging { - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - * @param sparkHome Location where Spark is installed on cluster nodes. - * @param jars Collection of JARs to send to the cluster. These can be paths on the local file - * system or HDFS, HTTP, HTTPS, or FTP URLs. - */ - def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = - this(master, jobName, sparkHome, jars, Map()) - - /** - * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI - */ - def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) - // Ensure logging is initialized before we spawn any threads initLogging() @@ -187,11 +169,32 @@ class SparkContext( private var dagScheduler = new DAGScheduler(taskScheduler) + /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ + val hadoopConfiguration = { + val conf = new Configuration() + // Explicitly check for S3 environment variables + if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { + conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID")) + conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) + } + // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" + for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) + } + val bufferSize = System.getProperty("spark.buffer.size", "65536") + conf.set("io.file.buffer.size", bufferSize) + conf + } + + private[spark] var checkpointDir: Option[String] = None + // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices) + new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -199,6 +202,14 @@ class SparkContext( parallelize(seq, numSlices) } + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. @@ -231,10 +242,8 @@ class SparkContext( valueClass: Class[V], minSplits: Int = defaultMinSplits ) : RDD[(K, V)] = { - val conf = new JobConf() + val conf = new JobConf(hadoopConfiguration) FileInputFormat.setInputPaths(conf, path) - val bufferSize = System.getProperty("spark.buffer.size", "65536") - conf.set("io.file.buffer.size", bufferSize) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -275,8 +284,7 @@ class SparkContext( path, fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], - new Configuration) + vm.erasure.asInstanceOf[Class[V]]) } /** @@ -288,7 +296,7 @@ class SparkContext( fClass: Class[F], kClass: Class[K], vClass: Class[V], - conf: Configuration): RDD[(K, V)] = { + conf: Configuration = hadoopConfiguration): RDD[(K, V)] = { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration @@ -300,7 +308,7 @@ class SparkContext( * and extra configuration options to pass to the input format. */ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( - conf: Configuration, + conf: Configuration = hadoopConfiguration, fClass: Class[F], kClass: Class[K], vClass: Class[V]): RDD[(K, V)] = { @@ -365,6 +373,13 @@ class SparkContext( .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes)) } + + protected[spark] def checkpointFile[T: ClassManifest]( + path: String + ): RDD[T] = { + new CheckpointRDD[T](this, path) + } + /** Build the union of a list of RDDs. */ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) @@ -382,11 +397,12 @@ class SparkContext( new Accumulator(initialValue, param) /** - * Create an [[spark.Accumulable]] shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`. + * Only the master can access the accumuable's `value`. * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ - def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = + def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = new Accumulable(initialValue, param) /** @@ -404,12 +420,13 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** - * Add a file to be downloaded into the working directory of this Spark job on every node. + * Add a file to be downloaded with this Spark job on every node. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. */ def addFile(path: String) { val uri = new URI(path) @@ -419,9 +436,10 @@ class SparkContext( } addedFiles(key) = System.currentTimeMillis - // Fetch the file locally in case the task is executed locally - val filename = new File(path.split("/").last) - Utils.fetchFile(path, new File(".")) + // Fetch the file locally in case a job is executed locally. + // Jobs that run through LocalScheduler will already fetch the required dependencies, + // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } @@ -437,11 +455,10 @@ class SparkContext( } /** - * Clear the job's list of files added by `addFile` so that they do not get donwloaded to + * Clear the job's list of files added by `addFile` so that they do not get downloaded to * any new nodes. */ def clearFiles() { - addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedFiles.clear() } @@ -465,23 +482,27 @@ class SparkContext( * any new nodes. */ def clearJars() { - addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } addedJars.clear() } /** Shut down the SparkContext. */ def stop() { - dagScheduler.stop() - dagScheduler = null - taskScheduler = null - // TODO: Cache.stop()? - env.stop() - // Clean up locally linked files - clearFiles() - clearJars() - SparkEnv.set(null) - ShuffleMapTask.clearCache() - logInfo("Successfully stopped SparkContext") + if (dagScheduler != null) { + dagScheduler.stop() + dagScheduler = null + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + // Clean up locally linked files + clearFiles() + clearJars() + SparkEnv.set(null) + ShuffleMapTask.clearCache() + ResultTask.clearCache() + logInfo("Successfully stopped SparkContext") + } else { + logInfo("SparkContext already stopped") + } } /** @@ -518,6 +539,7 @@ class SparkContext( val start = System.nanoTime val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") + rdd.doCheckpoint() result } @@ -574,6 +596,26 @@ class SparkContext( return f } + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. + */ + def setCheckpointDir(dir: String, useExisting: Boolean = false) { + val path = new Path(dir) + val fs = path.getFileSystem(new Configuration()) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } + } + checkpointDir = Some(dir) + } + /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism @@ -595,6 +637,7 @@ class SparkContext( * various Spark features. */ object SparkContext { + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 272d7cdad3..2a7a8af83d 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -22,24 +22,19 @@ class SparkEnv ( val actorSystem: ActorSystem, val serializer: Serializer, val closureSerializer: Serializer, - val cacheTracker: CacheTracker, + val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, - val httpFileServer: HttpFileServer + val httpFileServer: HttpFileServer, + val sparkFilesDir: String ) { - /** No-parameter constructor for unit tests. */ - def this() = { - this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) - } - def stop() { httpFileServer.stop() mapOutputTracker.stop() - cacheTracker.stop() shuffleFetcher.stop() broadcastManager.stop() blockManager.stop() @@ -86,10 +81,13 @@ object SparkEnv extends Logging { } val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") - - val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) + + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster( + actorSystem, isMaster, isLocal, masterIp, masterPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer) - + val connectionManager = blockManager.connectionManager val broadcastManager = new BroadcastManager(isMaster) @@ -97,18 +95,26 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClass[Serializer]( "spark.closure.serializer", "spark.JavaSerializer") - val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheManager = new CacheManager(blockManager) val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - + val httpFileServer = new HttpFileServer() httpFileServer.initialize() System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + // Set the sparkFiles directory, used when downloading dependencies. In local mode, + // this is a temporary directory; in distributed mode, this is the executor's current working + // directory. + val sparkFilesDir: String = if (isMaster) { + Utils.createTempDir().getAbsolutePath + } else { + "." + } + // Warn about deprecated spark.cache.class property if (System.getProperty("spark.cache.class") != null) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -119,12 +125,13 @@ object SparkEnv extends Logging { actorSystem, serializer, closureSerializer, - cacheTracker, + cacheManager, mapOutputTracker, shuffleFetcher, broadcastManager, blockManager, connectionManager, - httpFileServer) + httpFileServer, + sparkFilesDir) } } diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java new file mode 100644 index 0000000000..566aec622c --- /dev/null +++ b/core/src/main/scala/spark/SparkFiles.java @@ -0,0 +1,25 @@ +package spark; + +import java.io.File; + +/** + * Resolves paths to files added through `SparkContext.addFile()`. + */ +public class SparkFiles { + + private SparkFiles() {} + + /** + * Get the absolute path of a file added through `SparkContext.addFile()`. + */ + public static String get(String filename) { + return new File(getRootDirectory(), filename).getAbsolutePath(); + } + + /** + * Get the root directory that contains files added through `SparkContext.addFile()`. + */ + public static String getRootDirectory() { + return SparkEnv.get().sparkFilesDir(); + } +} diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala index d2746b26b3..eab85f85a2 100644 --- a/core/src/main/scala/spark/TaskContext.scala +++ b/core/src/main/scala/spark/TaskContext.scala @@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable { - @transient - val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] // Add a callback function to be executed on task completion. An example use // is for HadoopRDD to register a callback to close the input stream. diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 6d64b32174..ae77264372 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,7 +1,7 @@ package spark import java.io._ -import java.net.{NetworkInterface, InetAddress, URL, URI} +import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI} import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} import org.apache.hadoop.conf.Configuration @@ -9,6 +9,8 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source +import com.google.common.io.Files +import com.google.common.util.concurrent.ThreadFactoryBuilder /** * Various utility methods used by Spark. @@ -110,48 +112,56 @@ private object Utils extends Logging { } } - /** Copy a file on the local file system */ - def copyFile(source: File, dest: File) { - val in = new FileInputStream(source) - val out = new FileOutputStream(dest) - copyStream(in, out, true) - } - - /** Download a file from a given URL to the local filesystem */ - def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) - } - /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + * + * Throws SparkException if the target file already exists and has different contents than + * the requested file. */ def fetchFile(url: String, targetDir: File) { val filename = url.split("/").last + val tempDir = getLocalDir + val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) uri.getScheme match { case "http" | "https" | "ftp" => - logInfo("Fetching " + url + " to " + targetFile) + logInfo("Fetching " + url + " to " + tempFile) val in = new URL(url).openStream() - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + Files.move(tempFile, targetFile) + } case "file" | null => - // Remove the file if it already exists - targetFile.delete() - // Symlink the file locally. - if (uri.isAbsolute) { - // url is absolute, i.e. it starts with "file:///". Extract the source - // file's absolute path from the url. - val sourceFile = new File(uri) - logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + val sourceFile = if (uri.isAbsolute) { + new File(uri) + } else { + new File(url) + } + if (targetFile.exists && !Files.equal(sourceFile, targetFile)) { + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) } else { - // url is not absolute, i.e. itself is the path to the source file. - logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) - FileUtil.symLink(url, targetFile.getAbsolutePath) + // Remove the file if it already exists + targetFile.delete() + // Symlink the file locally. + if (uri.isAbsolute) { + // url is absolute, i.e. it starts with "file:///". Extract the source + // file's absolute path from the url. + val sourceFile = new File(uri) + logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + } else { + // url is not absolute, i.e. itself is the path to the source file. + logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(url, targetFile.getAbsolutePath) + } } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others @@ -159,8 +169,15 @@ private object Utils extends Logging { val conf = new Configuration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) - val out = new FileOutputStream(targetFile) + val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) + if (targetFile.exists && !Files.equal(tempFile, targetFile)) { + tempFile.delete() + throw new SparkException("File " + targetFile + " exists and does not match contents of" + + " " + url) + } else { + Files.move(tempFile, targetFile) + } } // Decompress the file if it's a .tar or .tar.gz if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { @@ -171,7 +188,16 @@ private object Utils extends Logging { Utils.execute(Seq("tar", "-xf", filename), targetDir) } // Make the file executable - That's necessary for scripts - FileUtil.chmod(filename, "a+x") + FileUtil.chmod(targetFile.getAbsolutePath, "a+x") + } + + /** + * Get a temporary directory using Spark's spark.local.dir property, if set. This will always + * return a single directory, even though the spark.local.dir property might be a list of + * multiple paths. + */ + def getLocalDir: String = { + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) } /** @@ -212,7 +238,8 @@ private object Utils extends Logging { // Address resolves to something like 127.0.1.1, which happens on Debian; try to find // a better address using the local network interfaces for (ni <- NetworkInterface.getNetworkInterfaces) { - for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) { + for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && + !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) { // We've found an address that looks reasonable! logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" + " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress + @@ -247,48 +274,28 @@ private object Utils extends Logging { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - /** - * Returns a standard ThreadFactory except all threads are daemons. - */ - private def newDaemonThreadFactory: ThreadFactory = { - new ThreadFactory { - def newThread(r: Runnable): Thread = { - var t = Executors.defaultThreadFactory.newThread (r) - t.setDaemon (true) - return t - } - } - } + private[spark] val daemonThreadFactory: ThreadFactory = + new ThreadFactoryBuilder().setDaemon(true).build() /** * Wrapper over newCachedThreadPool. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = { - var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory (newDaemonThreadFactory) - - return threadPool - } + def newDaemonCachedThreadPool(): ThreadPoolExecutor = + Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Return the string to tell how long has passed in seconds. The passing parameter should be in * millisecond. */ def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " + return " " + (System.currentTimeMillis - startTimeMs) + " ms" } /** * Wrapper over newFixedThreadPool. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = { - var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor] - - threadPool.setThreadFactory(newDaemonThreadFactory) - - return threadPool - } + def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = + Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] /** * Delete a file or directory and its contents recursively. diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 5c2be534ff..8ce32e0e2f 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending)) } + + /** + * Return an RDD with the keys of each tuple. + */ + def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) + + /** + * Return an RDD with the values of each tuple. + */ + def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2)) } object JavaPairRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 81d3a94466..b3698ffa44 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} import spark.storage.StorageLevel +import com.google.common.base.Optional trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { @@ -298,4 +299,36 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Save this RDD as a SequenceFile of serialized objects. */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) + + /** + * Creates tuples of the elements in this RDD by applying `f`. + */ + def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { + implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + JavaPairRDD.fromRDD(rdd.keyBy(f)) + } + + /** + * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint + * directory set with SparkContext.setCheckpointDir() and all references to its parent + * RDDs will be removed. This function must be called before any job has been + * executed on this RDD. It is strongly recommended that this RDD is persisted in + * memory, otherwise saving it on a file will require recomputation. + */ + def checkpoint() = rdd.checkpoint() + + /** + * Return whether this RDD has been checkpointed or not + */ + def isCheckpointed(): Boolean = rdd.isCheckpointed() + + /** + * Gets the name of the file to which this RDD was checkpointed + */ + def getCheckpointFile(): Optional[String] = { + rdd.getCheckpointFile match { + case Some(file) => Optional.of(file) + case _ => Optional.absent() + } + } } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index edbb187b1b..50b8970cd8 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext} import spark.SparkContext.IntAccumulatorParam import spark.SparkContext.DoubleAccumulatorParam import spark.broadcast.Broadcast @@ -265,26 +265,46 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def intAccumulator(initialValue: Int): Accumulator[Int] = - sc.accumulator(initialValue)(IntAccumulatorParam) + def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] = + sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]] /** * Create an [[spark.Accumulator]] double variable, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ - def doubleAccumulator(initialValue: Double): Accumulator[Double] = - sc.accumulator(initialValue)(DoubleAccumulatorParam) + def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] = + sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]] + + /** + * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue) + + /** + * Create an [[spark.Accumulator]] double variable, which tasks can "add" values + * to using the `add` method. Only the master can access the accumulator's `value`. + */ + def accumulator(initialValue: Double): Accumulator[java.lang.Double] = + doubleAccumulator(initialValue) /** * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values - * to using the `+=` method. Only the master can access the accumulator's `value`. + * to using the `add` method. Only the master can access the accumulator's `value`. */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) /** + * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can + * "add" values with `add`. Only the master can access the accumuable's `value`. + */ + def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] = + sc.accumulable(initialValue)(param) + + /** * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ @@ -301,6 +321,75 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * (in that order of preference). If neither of these is set, return None. */ def getSparkHome(): Option[String] = sc.getSparkHome() + + /** + * Add a file to be downloaded with this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(path)` to find its download location. + */ + def addFile(path: String) { + sc.addFile(path) + } + + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + sc.addJar(path) + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + sc.clearJars() + } + + /** + * Clear the job's list of files added by `addFile` so that they do not get downloaded to + * any new nodes. + */ + def clearFiles() { + sc.clearFiles() + } + + /** + * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. + */ + def hadoopConfiguration(): Configuration = { + sc.hadoopConfiguration + } + + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists and useExisting is set to true, then the + * exisiting directory will be used. Otherwise an exception will be thrown to + * prevent accidental overriding of checkpoint files in the existing directory. + */ + def setCheckpointDir(dir: String, useExisting: Boolean) { + sc.setCheckpointDir(dir, useExisting) + } + + /** + * Set the directory under which RDDs are going to be checkpointed. The directory must + * be a HDFS path if running on a cluster. If the directory does not exist, it will + * be created. If the directory exists, an exception will be thrown to prevent accidental + * overriding of checkpoint files. + */ + def setCheckpointDir(dir: String) { + sc.setCheckpointDir(dir) + } + + protected def checkpointFile[T](path: String): JavaRDD[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + new JavaRDD(sc.checkpointFile(path)) + } } object JavaSparkContext { diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java index 722af3c06c..5e5845ac3a 100644 --- a/core/src/main/scala/spark/api/java/StorageLevels.java +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -17,4 +17,15 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { + return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + } } diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala new file mode 100644 index 0000000000..519e310323 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -0,0 +1,48 @@ +package spark.api.python + +import spark.Partitioner + +import java.util.Arrays + +/** + * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API. + * + * Stores the unique id() of the Python-side partitioning function so that it is incorporated into + * equality comparisons. Correctness requires that the id is a unique identifier for the + * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * function). This can be ensured by using the Python id() function and maintaining a reference + * to the Python partitioning function so that its id() is not reused. + */ +private[spark] class PythonPartitioner( + override val numPartitions: Int, + val pyPartitionFunctionId: Long) + extends Partitioner { + + override def getPartition(key: Any): Int = { + if (key == null) { + return 0 + } + else { + val hashCode = { + if (key.isInstanceOf[Array[Byte]]) { + Arrays.hashCode(key.asInstanceOf[Array[Byte]]) + } else { + key.hashCode() + } + } + val mod = hashCode % numPartitions + if (mod < 0) { + mod + numPartitions + } else { + mod // Guard against negative hash codes + } + } + } + + override def equals(other: Any): Boolean = other match { + case h: PythonPartitioner => + h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId + case _ => + false + } +} diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala new file mode 100644 index 0000000000..f43a152ca7 --- /dev/null +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -0,0 +1,293 @@ +package spark.api.python + +import java.io._ +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} + +import scala.collection.JavaConversions._ +import scala.io.Source + +import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import spark.broadcast.Broadcast +import spark._ +import spark.rdd.PipedRDD + + +private[spark] class PythonRDD[T: ClassManifest]( + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) + extends RDD[Array[Byte]](parent) { + + // Similar to Runtime.exec(), if we are given a single string, split it into words + // using a standard StringTokenizer (i.e. by spaces) + def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = + this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, + broadcastVars, accumulator) + + override def getSplits = parent.splits + + override val partitioner = if (preservePartitoning) parent.partitioner else None + + override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { + val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") + + val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) + // Add the environmental variables to the process. + val currentEnvVars = pb.environment() + + for ((variable, value) <- envVars) { + currentEnvVars.put(variable, value) + } + + val proc = pb.start() + val env = SparkEnv.get + + // Start a thread to print the process's stderr to ours + new Thread("stderr reader for " + command) { + override def run() { + for (line <- Source.fromInputStream(proc.getErrorStream).getLines) { + System.err.println(line) + } + } + }.start() + + // Start a thread to feed the process input from our parent's iterator + new Thread("stdin writer for " + command) { + override def run() { + SparkEnv.set(env) + val out = new PrintWriter(proc.getOutputStream) + val dOut = new DataOutputStream(proc.getOutputStream) + // Split index + dOut.writeInt(split.index) + // sparkFilesDir + PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) + // Broadcast variables + dOut.writeInt(broadcastVars.length) + for (broadcast <- broadcastVars) { + dOut.writeLong(broadcast.id) + dOut.writeInt(broadcast.value.length) + dOut.write(broadcast.value) + dOut.flush() + } + // Serialized user code + for (elem <- command) { + out.println(elem) + } + out.flush() + // Data values + for (elem <- parent.iterator(split, context)) { + PythonRDD.writeAsPickle(elem, dOut) + } + dOut.flush() + out.flush() + proc.getOutputStream.close() + } + }.start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(proc.getInputStream) + return new Iterator[Array[Byte]] { + def next(): Array[Byte] = { + val obj = _nextObj + _nextObj = read() + obj + } + + private def read(): Array[Byte] = { + try { + val length = stream.readInt() + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } + } catch { + case eof: EOFException => { + val exitStatus = proc.waitFor() + if (exitStatus != 0) { + throw new Exception("Subprocess exited with status " + exitStatus) + } + new Array[Byte](0) + } + case e => throw e + } + } + + var _nextObj = read() + + def hasNext = _nextObj.length != 0 + } + } + + val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) +} + +/** + * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. + * This is used by PySpark's shuffle operations. + */ +private class PairwiseRDD(prev: RDD[Array[Byte]]) extends + RDD[(Array[Byte], Array[Byte])](prev) { + override def getSplits = prev.splits + override def compute(split: Split, context: TaskContext) = + prev.iterator(split, context).grouped(2).map { + case Seq(a, b) => (a, b) + case x => throw new Exception("PairwiseRDD: unexpected value: " + x) + } + val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this) +} + +private[spark] object PythonRDD { + + /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ + def stripPickle(arr: Array[Byte]) : Array[Byte] = { + arr.slice(2, arr.length - 1) + } + + /** + * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. + * The data format is a 32-bit integer representing the pickled object's length (in bytes), + * followed by the pickled data. + * + * Pickle module: + * + * http://docs.python.org/2/library/pickle.html + * + * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: + * + * http://hg.python.org/cpython/file/2.6/Lib/pickle.py + * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py + * + * @param elem the object to write + * @param dOut a data output stream + */ + def writeAsPickle(elem: Any, dOut: DataOutputStream) { + if (elem.isInstanceOf[Array[Byte]]) { + val arr = elem.asInstanceOf[Array[Byte]] + dOut.writeInt(arr.length) + dOut.write(arr) + } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { + val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] + val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(PythonRDD.stripPickle(t._1)) + dOut.write(PythonRDD.stripPickle(t._2)) + dOut.writeByte(Pickle.TUPLE2) + dOut.writeByte(Pickle.STOP) + } else if (elem.isInstanceOf[String]) { + // For uniformity, strings are wrapped into Pickles. + val s = elem.asInstanceOf[String].getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } else { + throw new Exception("Unexpected RDD type") + } + } + + def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + JavaRDD[Array[Byte]] = { + val file = new DataInputStream(new FileInputStream(filename)) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} + case e => throw e + } + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } + + def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + val file = new DataOutputStream(new FileOutputStream(filename)) + for (item <- items) { + writeAsPickle(item, file) + } + file.close() + } + + def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] = + rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head +} + +private object Pickle { + val PROTO: Byte = 0x80.toByte + val TWO: Byte = 0x02.toByte + val BINUNICODE: Byte = 'X' + val STOP: Byte = '.' + val TUPLE2: Byte = 0x86.toByte + val EMPTY_LIST: Byte = ']' + val MARK: Byte = '(' + val APPENDS: Byte = 'e' +} + +private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { + override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") +} + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index 6055bfd045..2ffe7f741d 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong import spark._ -abstract class Broadcast[T](id: Long) extends Serializable { +abstract class Broadcast[T](private[spark] val id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 7eb4ddb74f..8e490e6bad 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark._ import spark.storage.StorageLevel +import util.{MetadataCleaner, TimeStampedHashSet} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { @@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + private val files = new TimeStampedHashSet[String] + private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + + def initialize(isMaster: Boolean) { synchronized { if (!initialized) { @@ -85,11 +90,12 @@ private object HttpBroadcast extends Logging { server = null } initialized = false + cleaner.cancel() } } private def createServer() { - broadcastDir = Utils.createTempDir() + broadcastDir = Utils.createTempDir(Utils.getLocalDir) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri @@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging { val serOut = ser.serializeStream(out) serOut.writeObject(value) serOut.close() + files += file.getAbsolutePath } def read[T](id: Long): T = { @@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging { serIn.close() obj } + + def cleanup(cleanupTime: Long) { + val iterator = files.internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val (file, time) = (entry.getKey, entry.getValue) + if (time < cleanupTime) { + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } + } + } + } } diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 457122745b..35f40c6e91 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState import spark.deploy.master.{WorkerInfo, JobInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List -import scala.collection.mutable.HashMap private[spark] sealed trait DeployMessage extends Serializable @@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor( execId: Int, jobDesc: JobDescription, cores: Int, - memory: Int) + memory: Int, + sparkHome: String) extends DeployMessage diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 20879c5f11..7160fc05fc 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -4,7 +4,8 @@ private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, - val command: Command) + val command: Command, + val sparkHome: String) extends Serializable { val user = System.getProperty("user.name", "<unknown>") diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala new file mode 100644 index 0000000000..732fa08064 --- /dev/null +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -0,0 +1,78 @@ +package spark.deploy + +import master.{JobInfo, WorkerInfo} +import worker.ExecutorRunner +import cc.spray.json._ + +/** + * spray-json helper class containing implicit conversion to json for marshalling responses + */ +private[spark] object JsonProtocol extends DefaultJsonProtocol { + implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] { + def write(obj: WorkerInfo) = JsObject( + "id" -> JsString(obj.id), + "host" -> JsString(obj.host), + "webuiaddress" -> JsString(obj.webUiAddress), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed) + ) + } + + implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] { + def write(obj: JobInfo) = JsObject( + "starttime" -> JsNumber(obj.startTime), + "id" -> JsString(obj.id), + "name" -> JsString(obj.desc.name), + "cores" -> JsNumber(obj.desc.cores), + "user" -> JsString(obj.desc.user), + "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave), + "submitdate" -> JsString(obj.submitDate.toString)) + } + + implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] { + def write(obj: JobDescription) = JsObject( + "name" -> JsString(obj.name), + "cores" -> JsNumber(obj.cores), + "memoryperslave" -> JsNumber(obj.memoryPerSlave), + "user" -> JsString(obj.user) + ) + } + + implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] { + def write(obj: ExecutorRunner) = JsObject( + "id" -> JsNumber(obj.execId), + "memory" -> JsNumber(obj.memory), + "jobid" -> JsString(obj.jobId), + "jobdesc" -> obj.jobDesc.toJson.asJsObject + ) + } + + implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] { + def write(obj: MasterState) = JsObject( + "url" -> JsString("spark://" + obj.uri), + "workers" -> JsArray(obj.workers.toList.map(_.toJson)), + "cores" -> JsNumber(obj.workers.map(_.cores).sum), + "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum), + "memory" -> JsNumber(obj.workers.map(_.memory).sum), + "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum), + "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)), + "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson)) + ) + } + + implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] { + def write(obj: WorkerState) = JsObject( + "id" -> JsString(obj.workerId), + "masterurl" -> JsString(obj.masterUrl), + "masterwebuiurl" -> JsString(obj.masterWebUiUrl), + "cores" -> JsNumber(obj.cores), + "coresused" -> JsNumber(obj.coresUsed), + "memory" -> JsNumber(obj.memory), + "memoryused" -> JsNumber(obj.memoryUsed), + "executors" -> JsArray(obj.executors.toList.map(_.toJson)), + "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson)) + ) + } +} diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 57a7e123b7..8764c400e2 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -25,7 +25,7 @@ private[spark] object TestClient { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) val desc = new JobDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map())) + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) client.start() diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index b30c8e99b5..2c2cd0231b 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -156,7 +156,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor if (spreadOutJobs) { // Try to spread out each job among all the nodes, until it has all its cores for (job <- waitingJobs if job.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(canUse(job, _)).sortBy(_.coresFree).reverse + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(canUse(job, _)).sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum) @@ -172,7 +173,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec) + launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -185,7 +186,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor val coresToUse = math.min(worker.coresFree, job.coresLeft) if (coresToUse > 0) { val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec) + launchExecutor(worker, exec, job.desc.sparkHome) job.state = JobState.RUNNING } } @@ -194,15 +195,17 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) { + def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, publicAddress: String): WorkerInfo = { + // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them. + workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _) val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) workers += worker idToWorker(worker.id) = worker @@ -213,7 +216,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def removeWorker(worker: WorkerInfo) { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) - workers -= worker + worker.setState(WorkerState.DEAD) idToWorker -= worker.id actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 3cdd3721f5..458ee2d665 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -8,7 +8,11 @@ import akka.util.duration._ import cc.spray.Directives import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ +import cc.spray.http.MediaTypes +import cc.spray.typeconversion.SprayJsonSupport._ + import spark.deploy._ +import spark.deploy.JsonProtocol._ private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { @@ -19,29 +23,51 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct val handler = { get { - path("") { - completeWith { + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => val future = master ? RequestMasterState - future.map { - masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[MasterState]) + } + case _ => + completeWith { + val future = master ? RequestMasterState + future.map { + masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) + } } - } } ~ path("job") { - parameter("jobId") { jobId => - completeWith { + parameters("jobId", 'format ?) { + case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState - future.map { state => - val masterState = state.asInstanceOf[MasterState] - - // A bit ugly an inefficient, but we won't have a number of jobs - // so large that it will make a significant difference. - (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match { - case Some(job) => spark.deploy.master.html.job_details.render(job) - case _ => null + val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => job + case _ => null + } + } + } + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(jobInfo.mapTo[JobInfo]) + } + case (jobId, _) => + completeWith { + val future = master ? RequestMasterState + future.map { state => + val masterState = state.asInstanceOf[MasterState] + + masterState.activeJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => masterState.completedJobs.find(_.id == jobId) match { + case Some(job) => spark.deploy.master.html.job_details.render(job) + case _ => null + } + } } } - } } } ~ pathPrefix("static") { diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index a0a698ef04..5a7f5fef8a 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -14,7 +14,7 @@ private[spark] class WorkerInfo( val publicAddress: String) { var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info - + var state: WorkerState.Value = WorkerState.ALIVE var coresUsed = 0 var memoryUsed = 0 @@ -42,4 +42,8 @@ private[spark] class WorkerInfo( def webUiAddress : String = { "http://" + this.publicAddress + ":" + this.webUiPort } + + def setState(state: WorkerState.Value) = { + this.state = state + } } diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala new file mode 100644 index 0000000000..0bf35014c8 --- /dev/null +++ b/core/src/main/scala/spark/deploy/master/WorkerState.scala @@ -0,0 +1,7 @@ +package spark.deploy.master + +private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") { + type WorkerState = Value + + val ALIVE, DEAD, DECOMMISSIONED = Value +} diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index beceb55ecd..0d1fe2a6b4 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -106,11 +106,6 @@ private[spark] class ExecutorRunner( throw new IOException("Failed to create directory " + executorDir) } - // Download the files it depends on into it (disabled for now) - //for (url <- jobDesc.fileUrls) { - // fetchFile(url, executorDir) - //} - // Launch the process val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 7c9e588ea2..19bf2be118 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -119,10 +119,10 @@ private[spark] class Worker( logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) => + case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir) + jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) executors(jobId + "/" + execId) = manager manager.start() coresUsed += cores_ diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 340920025b..37524a7c82 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) { } def inferDefaultMemory(): Int = { - val bean = ManagementFactory.getOperatingSystemMXBean - .asInstanceOf[com.sun.management.OperatingSystemMXBean] - val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt + val ibmVendor = System.getProperty("java.vendor").contains("IBM") + var totalMb = 0 + try { + val bean = ManagementFactory.getOperatingSystemMXBean() + if (ibmVendor) { + val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } else { + val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean") + val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize") + totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt + } + } catch { + case e: Exception => { + totalMb = 2*1024 + System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") + } + } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index d06f4884ee..f9489d99fc 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -7,7 +7,11 @@ import akka.util.Timeout import akka.util.duration._ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ +import cc.spray.http.MediaTypes +import cc.spray.typeconversion.SprayJsonSupport._ + import spark.deploy.{WorkerState, RequestWorkerState} +import spark.deploy.JsonProtocol._ private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { @@ -18,13 +22,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct val handler = { get { - path("") { - completeWith{ + (path("") & parameters('format ?)) { + case Some(js) if js.equalsIgnoreCase("json") => { val future = worker ? RequestWorkerState - future.map { workerState => - spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + respondWithMediaType(MediaTypes.`application/json`) { ctx => + ctx.complete(future.mapTo[WorkerState]) } } + case _ => + completeWith{ + val future = worker ? RequestWorkerState + future.map { workerState => + spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) + } + } } ~ path("log") { parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2552958d27..28d9d40d43 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -159,22 +159,24 @@ private[spark] class Executor extends Logging { * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL - if (!urlClassLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - urlClassLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } } } } diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 915f71ba9f..a29bf974d2 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend( with ExecutorBackend with Logging { - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - var master: ActorRef = null override def preStart() { diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 80262ab7b4..c193bf7c8d 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) { val chunk = message.getChunkForSending(defaultChunkSize) if (chunk.isDefined) { messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) logDebug("Starting to send [" + message + "]") - message.started = true + if (!message.started) { + logDebug("Starting to send [" + message + "]") + message.started = true + message.startTime = System.currentTimeMillis + } return chunk } else { /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 642fa4b525..2ecd14f536 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -43,18 +43,17 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(4) + val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt) val serverChannel = ServerSocketChannel.open() val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new SynchronizedQueue[SendingConnection] + val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] val sendMessageRequests = new Queue[(Message, SendingConnection)] - implicit val futureExecContext = ExecutionContext.fromExecutor( - Executors.newCachedThreadPool(DaemonThreadFactory)) - + implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null serverChannel.configureBlocking(false) @@ -79,10 +78,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging { def run() { try { while(!selectorThread.isInterrupted) { - while(!connectionRequests.isEmpty) { - val sendingConnection = connectionRequests.dequeue + for( (connectionManagerId, sendingConnection) <- connectionRequests) { sendingConnection.connect() addConnection(sendingConnection) + connectionRequests -= connectionManagerId } sendMessageRequests.synchronized { while(!sendMessageRequests.isEmpty) { @@ -300,8 +299,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector) - connectionRequests += newConnection + val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector)) newConnection } val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress) @@ -473,6 +471,7 @@ private[spark] object ConnectionManager { val mb = size * count / 1024.0 / 1024.0 val ms = finishTime - startTime val tput = mb * 1000.0 / ms + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") println("--------------------------") println() } diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 47ceaf3c07..533e4610f3 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -13,8 +13,14 @@ import akka.util.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { + //<mesos cluster> - the master URL + //<slaves file> - a list slaves to run connectionTest on + //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts + //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10 + //[count] - how many times to run, default is 3 + //[await time in seconds] : await time (in seconds), default is 600 if (args.length < 2) { - println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") + println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } @@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{ /*println("Slaves")*/ /*slaves.foreach(println)*/ - - val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( + val tasknum = if (args.length > 2) args(2).toInt else slaves.length + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val count = if (args.length > 4) args(4).toInt else 3 + val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second + println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) + val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( i => SparkEnv.get.connectionManager.id).collect() println("\nSlave ConnectionManagerIds") slaveConnManagerIds.foreach(println) println - val count = 10 (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { + val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager val thisConnManagerId = connManager.id connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { @@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{ None }) - val size = 100 * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => Await.result(f, 1.second)) + val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" logInfo(resultStr) resultStr }).collect() diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index f98528a183..2c022f88e0 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -1,9 +1,7 @@ package spark.rdd import scala.collection.mutable.HashMap - -import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext} - +import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext} private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split { val index = idx @@ -11,22 +9,20 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) - extends RDD[T](sc) { + extends RDD[T](sc, Nil) { - @transient - val splits_ = (0 until blockIds.size).map(i => { + @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] }).toArray - @transient - lazy val locations_ = { + @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ val locations = blockManager.getLocations(blockIds) HashMap(blockIds.zip(locations):_*) } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split, context: TaskContext): Iterator[T] = { val blockManager = SparkEnv.get.blockManager @@ -38,9 +34,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) - override val dependencies: List[Dependency[_]] = Nil + override def clearDependencies() { + splits_ = null + } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 4a7e5f3d06..453d410ad4 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,37 +1,53 @@ package spark.rdd -import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext} +import java.io.{ObjectOutputStream, IOException} +import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext} private[spark] -class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { +class CartesianSplit( + idx: Int, + @transient rdd1: RDD[_], + @transient rdd2: RDD[_], + s1Index: Int, + s2Index: Int + ) extends Split { + var s1 = rdd1.splits(s1Index) + var s2 = rdd2.splits(s2Index) override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + s1 = rdd1.splits(s1Index) + s2 = rdd2.splits(s2Index) + oos.defaultWriteObject() + } } private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, - rdd1: RDD[T], - rdd2: RDD[U]) - extends RDD[Pair[T, U]](sc) + var rdd1 : RDD[T], + var rdd2 : RDD[U]) + extends RDD[Pair[T, U]](sc, Nil) with Serializable { val numSplitsInRdd2 = rdd2.splits.size - @transient - val splits_ = { + @transient var splits_ = { // create the cross product split val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) for (s1 <- rdd1.splits; s2 <- rdd2.splits) { val idx = s1.index * numSplitsInRdd2 + s2.index - array(idx) = new CartesianSplit(idx, s1, s2) + array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index) } array } - override def splits = splits_ + override def getSplits = splits_ - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } @@ -42,7 +58,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } - override val dependencies = List( + var deps_ = List( new NarrowDependency(rdd1) { def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) }, @@ -50,4 +66,13 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) + + override def getDependencies = deps_ + + override def clearDependencies() { + deps_ = Nil + splits_ = null + rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala new file mode 100644 index 0000000000..6f00f6ac73 --- /dev/null +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -0,0 +1,128 @@ +package spark.rdd + +import spark._ +import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io.{NullWritable, BytesWritable} +import org.apache.hadoop.util.ReflectionUtils +import org.apache.hadoop.fs.Path +import java.io.{File, IOException, EOFException} +import java.text.NumberFormat + +private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split { + override val index: Int = idx +} + +/** + * This RDD represents a RDD checkpoint file (similar to HadoopRDD). + */ +private[spark] +class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) + extends RDD[T](sc, Nil) { + + @transient val path = new Path(checkpointPath) + @transient val fs = path.getFileSystem(new Configuration()) + + @transient val splits_ : Array[Split] = { + val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted + splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray + } + + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + + override def getSplits = splits_ + + override def getPreferredLocations(split: Split): Seq[String] = { + val status = fs.getFileStatus(path) + val locations = fs.getFileBlockLocations(status, 0, status.getLen) + locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost") + } + + override def compute(split: Split, context: TaskContext): Iterator[T] = { + CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context) + } + + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. + } +} + +private[spark] object CheckpointRDD extends Logging { + + def splitIdToFileName(splitId: Int): String = { + val numfmt = NumberFormat.getInstance() + numfmt.setMinimumIntegerDigits(5) + numfmt.setGroupingUsed(false) + "part-" + numfmt.format(splitId) + } + + def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) { + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(new Configuration()) + + val finalOutputName = splitIdToFileName(context.splitId) + val finalOutputPath = new Path(outputDir, finalOutputName) + val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId) + + if (fs.exists(tempOutputPath)) { + throw new IOException("Checkpoint failed: temporary path " + + tempOutputPath + " already exists") + } + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = SparkEnv.get.serializer.newInstance() + val serializeStream = serializer.serializeStream(fileOutputStream) + serializeStream.writeAll(iterator) + serializeStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.delete(finalOutputPath, true)) { + throw new IOException("Checkpoint failed: failed to delete earlier output of task " + + context.attemptId) + } + if (!fs.rename(tempOutputPath, finalOutputPath)) { + throw new IOException("Checkpoint failed: failed to save output of task: " + + context.attemptId) + } + } + } + + def readFromFile[T](path: String, context: TaskContext): Iterator[T] = { + val inputPath = new Path(path) + val fs = inputPath.getFileSystem(new Configuration()) + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val fileInputStream = fs.open(inputPath, bufferSize) + val serializer = SparkEnv.get.serializer.newInstance() + val deserializeStream = serializer.deserializeStream(fileInputStream) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback(() => deserializeStream.close()) + + deserializeStream.asIterator.asInstanceOf[Iterator[T]] + } + + // Test whether CheckpointRDD generate expected number of splits despite + // each split file having multiple blocks. This needs to be run on a + // cluster (mesos or standalone) using HDFS. + def main(args: Array[String]) { + import spark._ + + val Array(cluster, hdfsPath) = args + val sc = new SparkContext(cluster, "CheckpointRDD Test") + val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) + val path = new Path(hdfsPath, "temp") + val fs = path.getFileSystem(new Configuration()) + sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) + val cpRDD = new CheckpointRDD[Int](sc, path.toString) + assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") + fs.delete(path) + } +} diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index de0d9fad88..8fafd27bb6 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,14 +1,30 @@ package spark.rdd +import java.io.{ObjectOutputStream, IOException} +import java.util.{HashMap => JHashMap} +import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} private[spark] sealed trait CoGroupSplitDep extends Serializable -private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep + +private[spark] case class NarrowCoGroupSplitDep( + rdd: RDD[_], + splitIndex: Int, + var split: Split + ) extends CoGroupSplitDep { + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } +} + private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] @@ -24,30 +40,29 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { val aggr = new CoGroupAggregator - @transient - override val dependencies = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { - val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - if (mapSideCombinedRDD.partitioner == Some(part)) { - logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD) - deps += new OneToOneDependency(mapSideCombinedRDD) + if (rdd.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + rdd) + deps += new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) + val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } deps.toList } - @transient - val splits_ : Array[Split] = { - val firstRdd = rdds.head + override def getDependencies = deps_ + + @transient var splits_ : Array[Split] = { val array = new Array[Split](part.numPartitions) for (i <- 0 until array.size) { array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => @@ -55,28 +70,33 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) case s: ShuffleDependency[_, _] => new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep case _ => - new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep } }.toList) } array } - override def splits = splits_ - + override def getSplits = splits_ + override val partitioner = Some(part) - override def preferredLocations(s: Split) = Nil - override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] val numRdds = split.deps.size - val map = new HashMap[K, Seq[ArrayBuffer[Any]]] + val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) + val seq = map.get(k) + if (seq != null) { + seq + } else { + val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) + map.put(k, seq) + seq + } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit, context)) { getSeq(k.asInstanceOf[K])(depNum) += v @@ -93,6 +113,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) } } - map.iterator + JavaConversions.mapAsScalaMap(map).iterator + } + + override def clearDependencies() { + deps_ = null + splits_ = null + rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 1affe0e0ef..167755bbba 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,9 +1,22 @@ package spark.rdd -import spark.{NarrowDependency, RDD, Split, TaskContext} +import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext} +import java.io.{ObjectOutputStream, IOException} +private[spark] case class CoalescedRDDSplit( + index: Int, + @transient rdd: RDD[_], + parentsIndices: Array[Int] + ) extends Split { + var parents: Seq[Split] = parentsIndices.map(rdd.splits(_)) -private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + parents = parentsIndices.map(rdd.splits(_)) + oos.defaultWriteObject() + } +} /** * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of @@ -13,34 +26,44 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, * or to avoid having a large number of small tasks when processing a directory with many files. */ -class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) - extends RDD[T](prev.context) { +class CoalescedRDD[T: ClassManifest]( + var prev: RDD[T], + maxPartitions: Int) + extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient val splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val prevSplits = prev.splits if (prevSplits.length < maxPartitions) { - prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } + prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } } else { (0 until maxPartitions).map { i => val rangeStart = (i * prevSplits.length) / maxPartitions val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions - new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd)) + new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray) }.toArray } } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split, context: TaskContext): Iterator[T] = { - split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { - parentSplit => prev.iterator(parentSplit, context) + split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => + firstParent[T].iterator(parentSplit, context) } } - val dependencies = List( + var deps_ : List[Dependency[_]] = List( new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) + splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices } ) + + override def getDependencies() = deps_ + + override def clearDependencies() { + deps_ = Nil + splits_ = null + prev = null + } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index b148da28de..6dbe235bd9 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -2,10 +2,15 @@ package spark.rdd import spark.{OneToOneDependency, RDD, Split, TaskContext} +private[spark] class FilteredRDD[T: ClassManifest]( + prev: RDD[T], + f: T => Boolean) + extends RDD[T](prev) { -private[spark] -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f) -}
\ No newline at end of file + override def getSplits = firstParent[T].splits + + override val partitioner = prev.partitioner // Since filter cannot change a partition's keys + + override def compute(split: Split, context: TaskContext) = + firstParent[T].iterator(split, context).filter(f) +} diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 785662b2da..1b604c66e2 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,16 +1,16 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{RDD, Split, TaskContext} + private[spark] class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { + extends RDD[U](prev) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) + override def getSplits = firstParent[T].splits override def compute(split: Split, context: TaskContext) = - prev.iterator(split, context).flatMap(f) + firstParent[T].iterator(split, context).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index fac8ffb4cb..051bffed19 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,12 +1,12 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{RDD, Split, TaskContext} +private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) + extends RDD[Array[T]](prev) { + + override def getSplits = firstParent[T].splits -private[spark] -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) override def compute(split: Split, context: TaskContext) = - Array(prev.iterator(split, context).toArray).iterator -}
\ No newline at end of file + Array(firstParent[T].iterator(split, context).toArray).iterator +} diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index ab163f569b..f547f53812 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) - extends Split - with Serializable { - + extends Split { + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -43,7 +42,7 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc) { + extends RDD[(K, V)](sc, Nil) { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @@ -64,7 +63,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] @@ -110,11 +109,13 @@ class HadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } - override val dependencies: List[Dependency[_]] = Nil + override def checkpoint() { + // Do nothing. Hadoop RDD should not be checkpointed. + } } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index c764505345..073f7d7d2a 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{RDD, Split, TaskContext} private[spark] @@ -8,11 +8,13 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false) - extends RDD[U](prev.context) { + extends RDD[U](prev) { - override val partitioner = if (preservesPartitioning) prev.partitioner else None + override val partitioner = + if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context)) + override def getSplits = firstParent[T].splits + + override def compute(split: Split, context: TaskContext) = + f(firstParent[T].iterator(split, context)) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 3d9888bd34..2ddc3d01b6 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -1,6 +1,7 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{RDD, Split, TaskContext} + /** * A variant of the MapPartitionsRDD that passes the split index into the @@ -11,12 +12,13 @@ private[spark] class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean) - extends RDD[U](prev.context) { + preservesPartitioning: Boolean + ) extends RDD[U](prev) { + + override def getSplits = firstParent[T].splits override val partitioner = if (preservesPartitioning) prev.partitioner else None - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split, context: TaskContext) = - f(split.index, prev.iterator(split, context)) + f(split.index, firstParent[T].iterator(split, context)) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 70fa8f4497..c6ceb272cd 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,14 +1,15 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{RDD, Split, TaskContext} private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: T => U) - extends RDD[U](prev.context) { + extends RDD[U](prev) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f) + override def getSplits = firstParent[T].splits + + override def compute(split: Split, context: TaskContext) = + firstParent[T].iterator(split, context).map(f) }
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index 197ed5ea17..c3b155fcbd 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -20,11 +20,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit } class NewHadoopRDD[K, V]( - sc: SparkContext, + sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], valueClass: Class[V], + keyClass: Class[K], + valueClass: Class[V], @transient conf: Configuration) - extends RDD[(K, V)](sc) + extends RDD[(K, V)](sc, Nil) with HadoopMapReduceUtil { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it @@ -36,11 +37,9 @@ class NewHadoopRDD[K, V]( formatter.format(new Date()) } - @transient - private val jobId = new JobID(jobtrackerId, id) + @transient private val jobId = new JobID(jobtrackerId, id) - @transient - private val splits_ : Array[Split] = { + @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray @@ -51,7 +50,7 @@ class NewHadoopRDD[K, V]( result } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] @@ -86,10 +85,8 @@ class NewHadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } - - override val dependencies: List[Dependency[_]] = Nil } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 336e193217..6631f83510 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,7 +8,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{RDD, SparkEnv, Split, TaskContext} /** @@ -16,18 +16,18 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext} * (printing them one per line) and returns the output as a collection of strings. */ class PipedRDD[T: ClassManifest]( - parent: RDD[T], command: Seq[String], envVars: Map[String, String]) - extends RDD[String](parent.context) { + prev: RDD[T], + command: Seq[String], + envVars: Map[String, String]) + extends RDD[String](prev) { - def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map()) + def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map()) // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command)) + def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = parent.splits - - override val dependencies = List(new OneToOneDependency(parent)) + override def getSplits = firstParent[T].splits override def compute(split: Split, context: TaskContext): Iterator[String] = { val pb = new ProcessBuilder(command) @@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest]( override def run() { SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) - for (elem <- parent.iterator(split, context)) { + for (elem <- firstParent[T].iterator(split, context)) { out.println(elem) } out.close() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 6e4797aabb..e24ad23b21 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -1,11 +1,11 @@ package spark.rdd import java.util.Random + import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import spark.{OneToOneDependency, RDD, Split, TaskContext} - +import spark.{RDD, Split, TaskContext} private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { @@ -14,23 +14,20 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali class SampledRDD[T: ClassManifest]( prev: RDD[T], - withReplacement: Boolean, + withReplacement: Boolean, frac: Double, seed: Int) - extends RDD[T](prev.context) { + extends RDD[T](prev) { - @transient - val splits_ = { + @transient var splits_ : Array[Split] = { val rg = new Random(seed) - prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def splits = splits_.asInstanceOf[Array[Split]] - - override val dependencies = List(new OneToOneDependency(prev)) + override def getSplits = splits_ - override def preferredLocations(split: Split) = - prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) + override def getPreferredLocations(split: Split) = + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split, context: TaskContext) = { val split = splitIn.asInstanceOf[SampledRDDSplit] @@ -38,7 +35,7 @@ class SampledRDD[T: ClassManifest]( // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. val poisson = new Poisson(frac, new DRand(split.seed)) - prev.iterator(split.prev, context).flatMap { element => + firstParent[T].iterator(split.prev, context).flatMap { element => val count = poisson.nextInt() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often @@ -48,7 +45,11 @@ class SampledRDD[T: ClassManifest]( } } else { // Sampling without replacement val rand = new Random(split.seed) - prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) + firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) } } + + override def clearDependencies() { + splits_ = null + } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index f832633646..28ff19876d 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,7 +1,7 @@ package spark.rdd -import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} - +import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} +import spark.SparkContext._ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { override val index = idx @@ -10,28 +10,28 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { /** * The resulting RDD from a shuffle (e.g. repartitioning of data). - * @param parent the parent RDD. + * @param prev the parent RDD. * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. */ class ShuffledRDD[K, V]( - @transient parent: RDD[(K, V)], - part: Partitioner) extends RDD[(K, V)](parent.context) { + prev: RDD[(K, V)], + part: Partitioner) + extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) { override val partitioner = Some(part) - @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - - override def splits = splits_ - - override def preferredLocations(split: Split) = Nil + @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - val dep = new ShuffleDependency(parent, part) - override val dependencies = List(dep) + override def getSplits = splits_ override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { - SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index) + val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId + SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) + } + + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a08473f7be..82f0a44ecd 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -1,43 +1,46 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer - import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext} +import java.io.{ObjectOutputStream, IOException} +private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Split { -private[spark] class UnionSplit[T: ClassManifest]( - idx: Int, - rdd: RDD[T], - split: Split) - extends Split - with Serializable { + var split: Split = rdd.splits(splitIndex) def iterator(context: TaskContext) = rdd.iterator(split, context) + def preferredLocations() = rdd.preferredLocations(split) + override val index: Int = idx + + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + split = rdd.splits(splitIndex) + oos.defaultWriteObject() + } } class UnionRDD[T: ClassManifest]( sc: SparkContext, - @transient rdds: Seq[RDD[T]]) - extends RDD[T](sc) - with Serializable { + @transient var rdds: Seq[RDD[T]]) + extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs - @transient - val splits_ : Array[Split] = { + @transient var splits_ : Array[Split] = { val array = new Array[Split](rdds.map(_.splits.size).sum) var pos = 0 for (rdd <- rdds; split <- rdd.splits) { - array(pos) = new UnionSplit(pos, rdd, split) + array(pos) = new UnionSplit(pos, rdd, split.index) pos += 1 } array } - override def splits = splits_ + override def getSplits = splits_ - @transient - override val dependencies = { + @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { @@ -47,9 +50,17 @@ class UnionRDD[T: ClassManifest]( deps.toList } + override def getDependencies = deps_ + override def compute(s: Split, context: TaskContext): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator(context) - override def preferredLocations(s: Split): Seq[String] = + override def getPreferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() + + override def clearDependencies() { + deps_ = null + splits_ = null + rdds = null + } } diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index 92d667ff1e..d950b06c85 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -1,53 +1,65 @@ package spark.rdd import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext} +import java.io.{ObjectOutputStream, IOException} private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest]( idx: Int, - rdd1: RDD[T], - rdd2: RDD[U], - split1: Split, - split2: Split) - extends Split - with Serializable { + @transient rdd1: RDD[T], + @transient rdd2: RDD[U] + ) extends Split { - def iterator(context: TaskContext): Iterator[(T, U)] = - rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) + var split1 = rdd1.splits(idx) + var split2 = rdd1.splits(idx) + override val index: Int = idx - def preferredLocations(): Seq[String] = - rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2)) + def splits = (split1, split2) - override val index: Int = idx + @throws(classOf[IOException]) + private def writeObject(oos: ObjectOutputStream) { + // Update the reference to parent split at the time of task serialization + split1 = rdd1.splits(idx) + split2 = rdd2.splits(idx) + oos.defaultWriteObject() + } } class ZippedRDD[T: ClassManifest, U: ClassManifest]( sc: SparkContext, - @transient rdd1: RDD[T], - @transient rdd2: RDD[U]) - extends RDD[(T, U)](sc) + var rdd1: RDD[T], + var rdd2: RDD[U]) + extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) with Serializable { - @transient - val splits_ : Array[Split] = { + // TODO: FIX THIS. + + @transient var splits_ : Array[Split] = { if (rdd1.splits.size != rdd2.splits.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } val array = new Array[Split](rdd1.splits.size) for (i <- 0 until rdd1.splits.size) { - array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i)) + array(i) = new ZippedSplit(i, rdd1, rdd2) } array } - override def splits = splits_ + override def getSplits = splits_ - @transient - override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)) + override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { + val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits + rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) + } - override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = - s.asInstanceOf[ZippedSplit[T, U]].iterator(context) + override def getPreferredLocations(s: Split): Seq[String] = { + val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits + rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2)) + } - override def preferredLocations(s: Split): Seq[String] = - s.asInstanceOf[ZippedSplit[T, U]].preferredLocations() + override def clearDependencies() { + splits_ = null + rdd1 = null + rdd2 = null + } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 29757b1178..b320be8863 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator import spark.partial.PartialResult import spark.storage.BlockManagerMaster import spark.storage.BlockManagerId +import util.{MetadataCleaner, TimeStampedHashMap} /** * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for @@ -61,28 +62,35 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val nextStageId = new AtomicInteger(0) - val idToStage = new HashMap[Int, Stage] + val idToStage = new TimeStampedHashMap[Int, Stage] - val shuffleToMapStage = new HashMap[Int, Stage] + val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] var cacheLocs = new HashMap[Int, Array[List[String]]] val env = SparkEnv.get - val cacheTracker = env.cacheTracker val mapOutputTracker = env.mapOutputTracker + val blockManagerMaster = env.blockManager.master - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general + // For tracking failed nodes, we use the MapOutputTracker's generation number, which is + // sent with every task. When we detect a node failing, we note the current generation number + // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask + // results. + // TODO: Garbage collect information about failure generations when we know there are no more + // stray messages to detect. + val failedGeneration = new HashMap[String, Long] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done val running = new HashSet[Stage] // Stages we are running right now val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] + val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + // Start a thread to run the DAGScheduler event loop new Thread("DAGScheduler") { setDaemon(true) @@ -92,11 +100,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with }.start() def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + if (!cacheLocs.contains(rdd.id)) { + val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { + locations => locations.map(_.ip).toList + }.toArray + } cacheLocs(rdd.id) } - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() + def clearCacheLocs() { + cacheLocs.clear } /** @@ -123,7 +137,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) } @@ -145,8 +158,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") - cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -247,7 +258,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) - updateCacheLocs() + clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + " output partitions") logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") @@ -290,7 +301,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with // on the failed node. if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { logInfo("Resubmitting failed stages") - updateCacheLocs() + clearCacheLocs() val failed2 = failed.toArray failed.clear() for (stage <- failed2.sortBy(_.priority)) { @@ -426,7 +437,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val status = event.result.asInstanceOf[MapStatus] val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos + if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) { + logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host) + } else { stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { @@ -436,11 +449,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with logInfo("waiting: " + waiting) logInfo("failed: " + failed) if (stage.shuffleDep != None) { + // We supply true to increment the generation number here in case this is a + // recomputation of the map outputs. In that case, some nodes may have cached + // locations with holes (from when we detected the error) and will need the + // generation incremented to refetch them. + // TODO: Only increment the generation number if this is not the first time + // we registered these map outputs. mapOutputTracker.registerMapOutputs( stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + true) } - updateCacheLocs() + clearCacheLocs() if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this @@ -492,7 +512,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock // TODO: mark the host as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleHostLost(bmAddress.ip) + handleHostLost(bmAddress.ip, Some(task.generation)) } case other => @@ -504,11 +524,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with /** * Responds to a host being lost. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + * + * Optionally the generation during which the failure was caught can be passed to avoid allowing + * stray fetch failures from possibly retriggering the detection of a node as lost. */ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { - logInfo("Host lost: " + host) - deadHosts += host + def handleHostLost(host: String, maybeGeneration: Option[Long] = None) { + val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration) + if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) { + failedGeneration(host) = currentGeneration + logInfo("Host lost: " + host + " (generation " + currentGeneration + ")") env.blockManager.master.notifyADeadHost(host) // TODO: This will be really slow if we keep accumulating shuffle map stages for ((shuffleId, stage) <- shuffleToMapStage) { @@ -516,8 +540,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray mapOutputTracker.registerMapOutputs(shuffleId, locs, true) } - cacheTracker.cacheLost(host) - updateCacheLocs() + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementGeneration() + } + clearCacheLocs() + } else { + logDebug("Additional host lost message for " + host + + "(generation " + currentGeneration + ")") } } @@ -594,8 +623,23 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return Nil } + def cleanup(cleanupTime: Long) { + var sizeBefore = idToStage.size + idToStage.clearOldValues(cleanupTime) + logInfo("idToStage " + sizeBefore + " --> " + idToStage.size) + + sizeBefore = shuffleToMapStage.size + shuffleToMapStage.clearOldValues(cleanupTime) + logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) + + sizeBefore = pendingTasks.size + pendingTasks.clearOldValues(cleanupTime) + logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) + } + def stop() { eventQueue.put(StopDAGScheduler) + metadataCleaner.cancel() taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala index 4532d9497f..fae643f3a8 100644 --- a/core/src/main/scala/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: } def readExternal(in: ObjectInput) { - address = new BlockManagerId(in) + address = BlockManagerId(in) compressedSizes = new Array[Byte](in.readInt()) in.readFully(compressedSizes) } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index e492279b4e..8cd4c661eb 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -1,26 +1,112 @@ package spark.scheduler import spark._ +import java.io._ +import util.{MetadataCleaner, TimeStampedHashMap} +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +private[spark] object ResultTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) + + def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId).orNull + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(func) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { + synchronized { + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] + return (rdd, func) + } + } + + def clearCache() { + synchronized { + serializedInfoCache.clear() + } + } +} + private[spark] class ResultTask[T, U]( stageId: Int, - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - val partition: Int, + var rdd: RDD[T], + var func: (TaskContext, Iterator[T]) => U, + var partition: Int, @transient locs: Seq[String], val outputId: Int) - extends Task[U](stageId) { + extends Task[U](stageId) with Externalizable { - val split = rdd.splits(partition) + def this() = this(0, null, null, 0, null, 0) + + var split = if (rdd == null) { + null + } else { + rdd.splits(partition) + } override def run(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) - val result = func(context, rdd.iterator(split, context)) - context.executeOnCompleteCallbacks() - result + try { + func(context, rdd.iterator(split, context)) + } finally { + context.executeOnCompleteCallbacks() + } } override def preferredLocations: Seq[String] = locs override def toString = "ResultTask(" + stageId + ", " + partition + ")" + + override def writeExternal(out: ObjectOutput) { + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ResultTask.serializeInfo( + stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeInt(outputId) + out.writeObject(split) + } + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) + rdd = rdd_.asInstanceOf[RDD[T]] + func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] + partition = in.readInt() + val outputId = in.readInt() + split = in.readObject().asInstanceOf[Split] + } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index bd1911fce2..19f5328eee 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -14,17 +14,20 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ +import util.{TimeStampedHashMap, MetadataCleaner} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new JHashMap[Int, Array[Byte]] + val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] + + val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { - val old = serializedInfoCache.get(stageId) + val old = serializedInfoCache.get(stageId).orNull if (old != null) { return old } else { @@ -87,13 +90,16 @@ private[spark] class ShuffleMapTask( } override def writeExternal(out: ObjectOutput) { - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partition) - out.writeLong(generation) - out.writeObject(split) + RDDCheckpointData.synchronized { + split = rdd.splits(partition) + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeLong(generation) + out.writeObject(split) + } } override def readExternal(in: ObjectInput) { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 20f6e65020..a639b72795 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) def slaveLost(slaveId: String, reason: ExecutorLossReason) { var failedHost: Option[String] = None synchronized { - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - logError("Lost an executor on " + host + ": " + reason) - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } else { - // We may get multiple slaveLost() calls with different loss reasons. For example, one - // may be triggered by a dropped connection from the slave while another may be a report - // of executor termination from Mesos. We produce log messages for both so we eventually - // report the termination reason. - logError("Lost an executor on " + host + " (already removed): " + reason) + slaveIdToHost.get(slaveId) match { + case Some(host) => + if (hostsAlive.contains(host)) { + logError("Lost an executor on " + host + ": " + reason) + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } else { + // We may get multiple slaveLost() calls with different loss reasons. For example, one + // may be triggered by a dropped connection from the slave while another may be a report + // of executor termination from Mesos. We produce log messages for both so we eventually + // report the termination reason. + logError("Lost an executor on " + host + " (already removed): " + reason) + } + case None => + // We were told about a slave being lost before we could even allocate work to it + logError("Lost slave " + slaveId + " (no work assigned yet)") } } if (failedHost != None) { diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index e2301347e5..4f82cd96dd 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -39,7 +39,8 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) + val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) + val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) client = new Client(sc.env.actorSystem, master, jobDesc, this) client.start() diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index cf4aae03a7..a089b71644 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -201,7 +201,11 @@ private[spark] class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else "non-preferred" + val prefStr = if (preferred) { + "preferred" + } else { + "non-preferred, not one of " + task.preferredLocations.mkString(", ") + } logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( taskSet.id, index, taskId, slaveId, host, prefStr)) // Do various bookkeeping diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb20fe41b2..9ff7c02097 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon with Logging { var attemptId = new AtomicInteger(0) - var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get var listener: TaskSchedulerListener = null @@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( ser.serialize(Accumulators.values)) logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, resultToReturn, accumUpdates) + + // If the threadpool has not already been shutdown, notify DAGScheduler + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - listener.taskEnded(task, new ExceptionFailure(t), null, null) + if (!Thread.currentThread().isInterrupted) + listener.taskEnded(task, new ExceptionFailure(t), null, null) } } } @@ -108,22 +112,24 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentFiles(name) = timestamp - } - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(".")) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(".", localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + synchronized { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL + if (!classLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + classLoader.addURL(url) + } } } } diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index c45c7df69c..014906b028 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt @@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { + private def getResource(res: JList[Resource], name: String): Double = { for (r <- res if r.getName == name) { return r.getScalar.getValue } @@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Build a Mesos resource protobuf object */ - def createResource(resourceName: String, quantity: Double): Protos.Resource = { + private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() .setName(resourceName) .setType(Value.Type.SCALAR) @@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend( } /** Check whether a Mesos task state represents a finished task */ - def isFinished(state: MesosTaskState) = { + private def isFinished(state: MesosTaskState) = { state == MesosTaskState.TASK_FINISHED || state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_KILLED || diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 8c7a1dfbc0..2989e31f5e 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend( } def createExecutorInfo(): ExecutorInfo = { - val sparkHome = sc.getSparkHome() match { - case Some(path) => - path - case None => - throw new SparkException("Spark home is not set; set it through the spark.home system " + - "property, the SPARK_HOME environment variable or the SparkContext constructor") - } + val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException( + "Spark home is not set; set it through the spark.home system " + + "property, the SPARK_HOME environment variable or the SparkContext constructor")) val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() sc.executorEnvs.foreach { case (key, value) => diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index df295b1820..19cdaaa984 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -1,59 +1,39 @@ package spark.storage -import akka.actor.{ActorSystem, Cancellable} -import akka.dispatch.{Await, Future} -import akka.util.Duration -import akka.util.duration._ - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput} -import java.nio.{MappedByteBuffer, ByteBuffer} +import java.io.{InputStream, OutputStream} +import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.collection.JavaConversions._ -import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils} -import spark.network._ -import spark.serializer.Serializer -import spark.util.ByteBufferInputStream -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import sun.nio.ch.DirectBuffer - - -private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) // For deserialization only - - def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) +import akka.actor.{ActorSystem, Cancellable, Props} +import akka.dispatch.{Await, Future} +import akka.util.Duration +import akka.util.duration._ - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - override def toString = "BlockManagerId(" + ip + ", " + port + ")" +import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils} +import spark.network._ +import spark.serializer.Serializer +import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap} - override def hashCode = ip.hashCode * 41 + port +import sun.nio.ch.DirectBuffer - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} private[spark] case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) private[spark] -class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, - val serializer: Serializer, maxMemory: Long) +class BlockManager( + actorSystem: ActorSystem, + val master: BlockManagerMaster, + val serializer: Serializer, + maxMemory: Long) extends Logging { class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { @@ -79,7 +59,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } - private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000) + private val blockInfo = new TimeStampedHashMap[String, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) private[storage] val diskStore: BlockStore = @@ -89,10 +69,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, implicit val futureExecContext = connectionManager.futureExecContext val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - - // TODO: This will be removed after cacheTracker is removed from the code base. - var cacheTracker: CacheTracker = null + val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port) // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) @@ -110,16 +87,20 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, val host = System.getProperty("spark.hostname", Utils.localHostName()) + val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + @volatile private var shuttingDown = false private def heartBeat() { - if (!master.mustHeartBeat(HeartBeat(blockManagerId))) { + if (!master.sendHeartBeat(blockManagerId)) { reregister() } } var heartBeatTask: Cancellable = null + val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) initialize() /** @@ -134,8 +115,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, * BlockManagerWorker actor. */ private def initialize() { - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) if (!BlockManager.getDisableHeartBeatsForTesting) { heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { @@ -156,8 +136,8 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, */ private def reportAllBlocks() { logInfo("Reporting " + blockInfo.size + " blocks to the master.") - for (blockId <- blockInfo.keys) { - if (!tryToReportBlockStatus(blockId)) { + for ((blockId, info) <- blockInfo) { + if (!tryToReportBlockStatus(blockId, info)) { logError("Failed to report " + blockId + " to master; giving up.") return } @@ -171,26 +151,22 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, def reregister() { // TODO: We might need to rate limit reregistering. logInfo("BlockManager reregistering with master") - master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory)) + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) reportAllBlocks() } /** * Get storage level of local block. If no info exists for the block, then returns null. */ - def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null - } + def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull /** * Tell the master about the current storage status of a block. This will send a block update * message reflecting the current status, *not* the desired storage level in its block info. * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ - def reportBlockStatus(blockId: String) { - val needReregister = !tryToReportBlockStatus(blockId) + def reportBlockStatus(blockId: String, info: BlockInfo) { + val needReregister = !tryToReportBlockStatus(blockId, info) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) // Reregistering will report our new block for free. @@ -200,33 +176,27 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } /** - * Actually send a BlockUpdate message. Returns the mater's response, which will be true if the - * block was successfully recorded and false if the slave needs to re-register. + * Actually send a UpdateBlockInfo message. Returns the mater's response, + * which will be true if the block was successfully recorded and false if + * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String): Boolean = { - val (curLevel, inMemSize, onDiskSize, tellMaster) = blockInfo.get(blockId) match { - case null => - (StorageLevel.NONE, 0L, 0L, false) - case info => - info.synchronized { - info.level match { - case null => - (StorageLevel.NONE, 0L, 0L, false) - case level => - val inMem = level.useMemory && memoryStore.contains(blockId) - val onDisk = level.useDisk && diskStore.contains(blockId) - ( - new StorageLevel(onDisk, inMem, level.deserialized, level.replication), - if (inMem) memoryStore.getSize(blockId) else 0L, - if (onDisk) diskStore.getSize(blockId) else 0L, - info.tellMaster - ) - } - } + private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = { + val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { + info.level match { + case null => + (StorageLevel.NONE, 0L, 0L, false) + case level => + val inMem = level.useMemory && memoryStore.contains(blockId) + val onDisk = level.useDisk && diskStore.contains(blockId) + val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication) + val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L + (storageLevel, memSize, diskSize, info.tellMaster) + } } if (tellMaster) { - master.mustBlockUpdate(BlockUpdate(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)) + master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize) } else { true } @@ -238,7 +208,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers = master.mustGetLocations(GetLocations(blockId)) + var managers = master.getLocations(blockId) val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations @@ -249,8 +219,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, */ def getLocations(blockIds: Array[String]): Array[Seq[String]] = { val startTimeMs = System.currentTimeMillis - val locations = master.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray + val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -272,7 +241,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -357,7 +326,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { info.waitForReady() // In case the block is still being put() by another thread @@ -413,7 +382,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } logDebug("Getting remote block " + blockId) // Get locations of block - val locations = master.mustGetLocations(GetLocations(blockId)) + val locations = master.getLocations(blockId) // Get block from remote locations for (loc <- locations) { @@ -615,7 +584,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, throw new IllegalArgumentException("Storage level is null or invalid") } - val oldBlock = blockInfo.get(blockId) + val oldBlock = blockInfo.get(blockId).orNull if (oldBlock != null) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") oldBlock.waitForReady() @@ -670,7 +639,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, // and tell the master about it. myInfo.markReady(size) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) @@ -690,10 +659,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, BlockManager.dispose(bytesAfterPut) - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) return size @@ -716,7 +681,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, throw new IllegalArgumentException("Storage level is null or invalid") } - if (blockInfo.containsKey(blockId)) { + if (blockInfo.contains(blockId)) { logWarning("Block " + blockId + " already exists on this machine; not re-adding it") return } @@ -757,15 +722,10 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, // and tell the master about it. myInfo.markReady(bytes.limit) if (tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, myInfo) } } - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyCacheTracker(blockId) - } - // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -788,10 +748,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { - cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime @@ -808,16 +767,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } - // TODO: This code will be removed when CacheTracker is gone. - private def notifyCacheTracker(key: String) { - if (cacheTracker != null) { - val rddInfo = key.split("_") - val rddId: Int = rddInfo(1).toInt - val partition: Int = rddInfo(2).toInt - cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) - } - } - /** * Read a block consisting of a single object. */ @@ -838,7 +787,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, */ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") - val info = blockInfo.get(blockId) + val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { val level = info.level @@ -851,9 +800,12 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, diskStore.putBytes(blockId, bytes, level) } } - memoryStore.remove(blockId) + val blockWasRemoved = memoryStore.remove(blockId) + if (!blockWasRemoved) { + logWarning("Block " + blockId + " could not be dropped from memory as it does not exist") + } if (info.tellMaster) { - reportBlockStatus(blockId) + reportBlockStatus(blockId, info) } if (!level.useDisk) { // The block is completely gone from this node; forget it so we can put() it again later. @@ -865,6 +817,53 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, } } + /** + * Remove a block from both memory and disk. + */ + def removeBlock(blockId: String) { + logInfo("Removing block " + blockId) + val info = blockInfo.get(blockId).orNull + if (info != null) info.synchronized { + // Removals are idempotent in disk store and memory store. At worst, we get a warning. + val removedFromMemory = memoryStore.remove(blockId) + val removedFromDisk = diskStore.remove(blockId) + if (!removedFromMemory && !removedFromDisk) { + logWarning("Block " + blockId + " could not be removed as it was not found in either " + + "the disk or memory store") + } + blockInfo.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId, info) + } + } else { + // The block has already been removed; do nothing. + logWarning("Asked to remove block " + blockId + ", which does not exist") + } + } + + def dropOldBlocks(cleanupTime: Long) { + logInfo("Dropping blocks older than " + cleanupTime) + val iterator = blockInfo.internalMap.entrySet().iterator() + while (iterator.hasNext) { + val entry = iterator.next() + val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + if (time < cleanupTime) { + info.synchronized { + val level = info.level + if (level.useMemory) { + memoryStore.remove(id) + } + if (level.useDisk) { + diskStore.remove(id) + } + iterator.remove() + logInfo("Dropped block " + id) + } + reportBlockStatus(id, info) + } + } + } + def shouldCompress(blockId: String): Boolean = { if (blockId.startsWith("shuffle_")) { compressShuffle @@ -914,6 +913,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, heartBeatTask.cancel() } connectionManager.stop() + master.actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diskStore.clear() @@ -923,6 +923,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster, private[spark] object BlockManager extends Logging { + + val ID_GENERATOR = new IdGenerator + def getMaxMemoryFromSystemProperties: Long = { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala new file mode 100644 index 0000000000..abb8b45a1f --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerId.scala @@ -0,0 +1,70 @@ +package spark.storage + +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.util.concurrent.ConcurrentHashMap + +/** + * This class represent an unique identifier for a BlockManager. + * The first 2 constructors of this class is made private to ensure that + * BlockManagerId objects can be created only using the factory method in + * [[spark.storage.BlockManager$]]. This allows de-duplication of id objects. + * Also, constructor parameters are private to ensure that parameters cannot + * be modified from outside this class. + */ +private[spark] class BlockManagerId private ( + private var ip_ : String, + private var port_ : Int + ) extends Externalizable { + + private def this() = this(null, 0) // For deserialization only + + def ip = ip_ + + def port = port_ + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip_) + out.writeInt(port_) + } + + override def readExternal(in: ObjectInput) { + ip_ = in.readUTF() + port_ = in.readInt() + } + + @throws(classOf[IOException]) + private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + + +private[spark] object BlockManagerId { + + def apply(ip: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(ip, port)) + + def apply(in: ObjectInput) = { + val obj = new BlockManagerId() + obj.readExternal(in) + getCachedBlockManagerId(obj) + } + + val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]() + + def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = { + if (blockManagerIdCache.containsKey(id)) { + blockManagerIdCache.get(id) + } else { + blockManagerIdCache.put(id, id) + id + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 0a4e68f437..a3d8671834 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -1,676 +1,167 @@ package spark.storage -import java.io._ -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.ArrayBuffer import scala.util.Random -import akka.actor._ -import akka.dispatch._ +import akka.actor.{Actor, ActorRef, ActorSystem, Props} +import akka.dispatch.Await import akka.pattern.ask -import akka.remote._ import akka.util.{Duration, Timeout} import akka.util.duration._ import spark.{Logging, SparkException, Utils} -private[spark] -sealed trait ToBlockManagerMaster - -private[spark] -case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long) - extends ToBlockManagerMaster - -private[spark] -case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - -private[spark] -class BlockUpdate( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var memSize: Long, - var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { - - def this() = this(null, null, null, 0, 0) // For deserialization only - - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeInt(memSize.toInt) - out.writeInt(diskSize.toInt) - } - - override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) - blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) - memSize = in.readInt() - diskSize = in.readInt() - } -} - -private[spark] -object BlockUpdate { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - memSize: Long, - diskSize: Long): BlockUpdate = { - new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize) - } - - // For pattern-matching - def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) - } -} - -private[spark] -case class GetLocations(blockId: String) extends ToBlockManagerMaster - -private[spark] -case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster - -private[spark] -case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster - -private[spark] -case class RemoveHost(host: String) extends ToBlockManagerMaster - -private[spark] -case object StopBlockManagerMaster extends ToBlockManagerMaster - -private[spark] -case object GetMemoryStatus extends ToBlockManagerMaster - -private[spark] -case object ExpireDeadHosts extends ToBlockManagerMaster - - -private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { - - class BlockManagerInfo( - val blockManagerId: BlockManagerId, - timeMs: Long, - val maxMem: Long) { - private var _lastSeenMs = timeMs - private var _remainingMem = maxMem - private val _blocks = new JHashMap[String, StorageLevel] - - logInfo("Registering block manager %s:%d with %s RAM".format( - blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) - - def updateLastSeenMs() { - _lastSeenMs = System.currentTimeMillis() - } - - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) - : Unit = synchronized { - - updateLastSeenMs() - - if (_blocks.containsKey(blockId)) { - // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId) - - if (originalLevel.useMemory) { - _remainingMem += memSize - } - } - - if (storageLevel.isValid) { - // isValid means it is either stored in-memory or on-disk. - _blocks.put(blockId, storageLevel) - if (storageLevel.useMemory) { - _remainingMem -= memSize - logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (storageLevel.useDisk) { - logInfo("Added %s on disk on %s:%d (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } else if (_blocks.containsKey(blockId)) { - // If isValid is not true, drop the block. - val originalLevel: StorageLevel = _blocks.get(blockId) - _blocks.remove(blockId) - if (originalLevel.useMemory) { - _remainingMem += memSize - logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), - Utils.memoryBytesToString(_remainingMem))) - } - if (originalLevel.useDisk) { - logInfo("Removed %s on %s:%d on disk (size: %s)".format( - blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) - } - } - } - - def remainingMem: Long = _remainingMem - - def lastSeenMs: Long = _lastSeenMs - - def blocks: JHashMap[String, StorageLevel] = _blocks - - override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem - - def clear() { - _blocks.clear() - } - } - - private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] - private val blockManagerIdByHost = new HashMap[String, BlockManagerId] - private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] - - initLogging() - - val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong - - val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", - "5000").toLong - - var timeoutCheckingTask: Cancellable = null +private[spark] class BlockManagerMaster( + val actorSystem: ActorSystem, + isMaster: Boolean, + isLocal: Boolean, + masterIp: String, + masterPort: Int) + extends Logging { - override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { - timeoutCheckingTask = context.system.scheduler.schedule( - 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - } - super.preStart() - } + val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt - def removeBlockManager(blockManagerId: BlockManagerId) { - val info = blockManagerInfo(blockManagerId) - blockManagerIdByHost.remove(blockManagerId.ip) - blockManagerInfo.remove(blockManagerId) - var iterator = info.blocks.keySet.iterator - while (iterator.hasNext) { - val blockId = iterator.next - val locations = blockInfo.get(blockId)._2 - locations -= blockManagerId - if (locations.size == 0) { - blockInfo.remove(locations) - } - } - } - - def expireDeadHosts() { - logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") - val now = System.currentTimeMillis() - val minSeenTime = now - slaveTimeout - val toRemove = new HashSet[BlockManagerId] - for (info <- blockManagerInfo.values) { - if (info.lastSeenMs < minSeenTime) { - logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats") - toRemove += info.blockManagerId - } - } - // TODO: Remove corresponding block infos - toRemove.foreach(removeBlockManager) - } - - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - blockManagerIdByHost.get(host).foreach(removeBlockManager) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) - sender ! true - } + val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager" + val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager" + val DEFAULT_MANAGER_IP: String = Utils.localHostName() - def heartBeat(blockManagerId: BlockManagerId) { - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - sender ! true - } else { - sender ! false - } + val timeout = 10.seconds + var masterActor: ActorRef = { + if (isMaster) { + val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)), + name = MASTER_AKKA_ACTOR_NAME) + logInfo("Registered BlockManagerMaster Actor") + masterActor } else { - blockManagerInfo(blockManagerId).updateLastSeenMs() - sender ! true + val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME) + logInfo("Connecting to BlockManagerMaster: " + url) + actorSystem.actorFor(url) } } - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize) => - register(blockManagerId, maxMemSize) - - case BlockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) => - blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - getPeersDeterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ - - case GetMemoryStatus => - getMemoryStatus - - case RemoveHost(host) => - removeHost(host) - sender ! true - - case StopBlockManagerMaster => - logInfo("Stopping BlockManagerMaster") - sender ! true - if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel - } - context.stop(self) - - case ExpireDeadHosts => - expireDeadHosts() - - case HeartBeat(blockManagerId) => - heartBeat(blockManagerId) - - case other => - logInfo("Got unknown message: " + other) + /** Remove a dead host from the master actor. This is only called on the master side. */ + def notifyADeadHost(host: String) { + tell(RemoveHost(host)) + logInfo("Removed " + host + " successfully in notifyADeadHost") } - // Return a map from the block manager id to max memory and remaining memory. - private def getMemoryStatus() { - val res = blockManagerInfo.map { case(blockManagerId, info) => - (blockManagerId, (info.maxMem, info.remainingMem)) - }.toMap - sender ! res + /** + * Send the master actor a heart beat from the slave. Returns true if everything works out, + * false if the master does not know about the given block manager, which means the block + * manager should re-register. + */ + def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = { + askMasterWithRetry[Boolean](HeartBeat(blockManagerId)) } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockManagerIdByHost.contains(blockManagerId.ip) && - blockManagerIdByHost(blockManagerId.ip) != blockManagerId) { - val oldId = blockManagerIdByHost(blockManagerId.ip) - logInfo("Got second registration for host " + blockManagerId + - "; removing old slave " + oldId) - removeBlockManager(oldId) - } - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else { - blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - blockManagerId, System.currentTimeMillis(), maxMemSize)) - } - blockManagerIdByHost += (blockManagerId.ip -> blockManagerId) - logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) - sender ! true + /** Register the BlockManager's id with the master. */ + def registerBlockManager( + blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + logInfo("Trying to register BlockManager") + tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) + logInfo("Registered BlockManager") } - private def blockUpdate( + def updateBlockInfo( blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { - - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - - if (!blockManagerInfo.contains(blockManagerId)) { - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - // We intentionally do not register the master (except in local mode), - // so we should not indicate failure. - sender ! true - } else { - sender ! false - } - return - } - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) - sender ! true - return - } - - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) - - var locations: HashSet[BlockManagerId] = null - if (blockInfo.containsKey(blockId)) { - locations = blockInfo.get(blockId)._2 - } else { - locations = new HashSet[BlockManagerId] - blockInfo.put(blockId, (storageLevel.replication, locations)) - } - - if (storageLevel.isValid) { - locations += blockManagerId - } else { - locations.remove(blockManagerId) - } - - if (locations.size == 0) { - blockInfo.remove(blockId) - } - sender ! true + diskSize: Long): Boolean = { + val res = askMasterWithRetry[Boolean]( + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + logInfo("Updated info of block " + blockId) + res } - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " - + Utils.getUsedTimeMs(startTimeMs)) - sender ! res.toSeq - } else { - logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - sender ! res - } + /** Get locations of the blockId from the master */ + def getLocations(blockId: String): Seq[BlockManagerId] = { + askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) - return res.toSeq - } else { - logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) - sender ! res.toSeq + /** Get locations of multiple blockIds from the master */ + def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) + /** Get ids of other nodes in the cluster from the master */ + def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { + val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) + if (result.length != numPeers) { + throw new SparkException( + "Error getting peers, only got " + result.size + " instead of " + numPeers) } - sender ! res.toSeq + result } - private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - - val peersWithIndices = peers.zipWithIndex - val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) - if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") - } - - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) - } - sender ! res.toSeq + /** + * Remove a block from the slaves that have it. This can only be used to remove + * blocks that the master knows about. + */ + def removeBlock(blockId: String) { + askMasterWithRetry(RemoveBlock(blockId)) } -} - -private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) - extends Logging { - - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val timeout = 10.seconds - var masterActor: ActorRef = null - - if (isMaster) { - masterActor = actorSystem.actorOf( - Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME) - logInfo("Registered BlockManagerMaster Actor") - } else { - val url = "akka://spark@%s:%s/user/%s".format( - DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) - logInfo("Connecting to BlockManagerMaster: " + url) - masterActor = actorSystem.actorFor(url) + /** + * Return the memory status for each block manager, in the form of a map from + * the block manager's id to two long values. The first value is the maximum + * amount of memory allocated for the block manager, while the second is the + * amount of remaining memory. + */ + def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { + askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } + /** Stop the master actor, called only on the Spark master node */ def stop() { if (masterActor != null) { - communicate(StopBlockManagerMaster) + tell(StopBlockManagerMaster) masterActor = null logInfo("BlockManagerMaster stopped") } } - // Send a message to the master actor and get its result within a default timeout, or - // throw a SparkException if this fails. - def askMaster(message: Any): Any = { - try { - val future = masterActor.ask(message)(timeout) - return Await.result(future, timeout) - } catch { - case e: Exception => - throw new SparkException("Error communicating with BlockManagerMaster", e) - } - } - - // Send a one-way message to the master actor, to which we expect it to reply with true. - def communicate(message: Any) { - if (askMaster(message) != true) { - throw new SparkException("Error reply received from BlockManagerMaster") - } - } - - def notifyADeadHost(host: String) { - communicate(RemoveHost(host)) - logInfo("Removed " + host + " successfully in notifyADeadHost") - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { - logInfo("Trying to register BlockManager") - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - logInfo("Done registering BlockManager") - } - - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - communicate(msg) - logInfo("BlockManager registered successfully @ syncRegisterBlockManager") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - } catch { - case e: Exception => - logError("Failed in syncRegisterBlockManager", e) - return false - } - } - - def mustHeartBeat(msg: HeartBeat): Boolean = { - var res = syncHeartBeat(msg) - while (!res.isDefined) { - logWarning("Failed to send heart beat " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + /** Send a one-way message to the master actor, to which we expect it to reply with true. */ + private def tell(message: Any) { + if (!askMasterWithRetry[Boolean](message)) { + throw new SparkException("BlockManagerMasterActor returned false, expected true.") } - return res.get } - def syncHeartBeat(msg: HeartBeat): Option[Boolean] = { - try { - val answer = askMaster(msg).asInstanceOf[Boolean] - return Some(answer) - } catch { - case e: Exception => - logError("Failed in syncHeartBeat", e) - return None + /** + * Send a message to the master actor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + private def askMasterWithRetry[T](message: Any): T = { + // TODO: Consider removing multiple attempts + if (masterActor == null) { + throw new SparkException("Error sending message to BlockManager as masterActor is null " + + "[message = " + message + "]") } - } - - def mustBlockUpdate(msg: BlockUpdate): Boolean = { - var res = syncBlockUpdate(msg) - while (!res.isDefined) { - logWarning("Failed to send block update " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - return res.get - } - - def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncBlockUpdate " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Boolean] - logDebug("Block update sent successfully") - logDebug("Got in synbBlockUpdate " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return Some(answer) - } catch { - case e: Exception => - logError("Failed in syncBlockUpdate", e) - return None - } - } - - def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - var res = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res - } - - def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] - if (answer != null) { - logDebug("GetLocations successful") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocations") - return null + var attempts = 0 + var lastException: Exception = null + while (attempts < AKKA_RETRY_ATTEMPS) { + attempts += 1 + try { + val future = masterActor.ask(message)(timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new Exception("BlockManagerMaster returned null") + } + return result.asInstanceOf[T] + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e) } - } catch { - case e: Exception => - logError("GetLocations failed", e) - return null + Thread.sleep(AKKA_RETRY_INTERVAL_MS) } - } - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) - } - return res + throw new SparkException( + "Error sending message to BlockManagerMaster [message = " + message + "]", lastException) } - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] - if (answer != null) { - logDebug("GetLocationsMultipleBlockIds successful") - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + - Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetLocationsMultipleBlockIds") - return null - } - } catch { - case e: Exception => - logError("GetLocationsMultipleBlockIds failed", e) - return null - } - } - - def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - var res = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res - } - - def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - try { - val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] - if (answer != null) { - logDebug("GetPeers successful") - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return answer - } else { - logError("Master replied null in response to GetPeers") - return null - } - } catch { - case e: Exception => - logError("GetPeers failed", e) - return null - } - } - - def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]] - } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala new file mode 100644 index 0000000000..f4d026da33 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala @@ -0,0 +1,401 @@ +package spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.JavaConversions._ +import scala.util.Random + +import akka.actor.{Actor, ActorRef, Cancellable} +import akka.util.{Duration, Timeout} +import akka.util.duration._ + +import spark.{Logging, Utils} + +/** + * BlockManagerMasterActor is an actor on the master node to track statuses of + * all slaves' block managers. + */ +private[spark] +class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { + + // Mapping from block manager id to the block manager's information. + private val blockManagerInfo = + new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo] + + // Mapping from host name to block manager id. We allow multiple block managers + // on the same host name (ip). + private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]] + + // Mapping from block id to the set of block managers that have the block. + private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] + + initLogging() + + val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs", + "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong + + val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", + "5000").toLong + + var timeoutCheckingTask: Cancellable = null + + override def preStart() { + if (!BlockManager.getDisableHeartBeatsForTesting) { + timeoutCheckingTask = context.system.scheduler.schedule( + 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) + } + super.preStart() + } + + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => + register(blockManagerId, maxMemSize, slaveActor) + + case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + getPeersDeterministic(blockManagerId, size) + /*getPeers(blockManagerId, size)*/ + + case GetMemoryStatus => + getMemoryStatus + + case RemoveBlock(blockId) => + removeBlock(blockId) + + case RemoveHost(host) => + removeHost(host) + sender ! true + + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + if (timeoutCheckingTask != null) { + timeoutCheckingTask.cancel + } + context.stop(self) + + case ExpireDeadHosts => + expireDeadHosts() + + case HeartBeat(blockManagerId) => + heartBeat(blockManagerId) + + case other => + logInfo("Got unknown message: " + other) + } + + def removeBlockManager(blockManagerId: BlockManagerId) { + val info = blockManagerInfo(blockManagerId) + + // Remove the block manager from blockManagerIdByHost. If the list of block + // managers belonging to the IP is empty, remove the entry from the hash map. + blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] => + managers -= blockManagerId + if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip) + } + + // Remove it from blockManagerInfo and remove all the blocks. + blockManagerInfo.remove(blockManagerId) + var iterator = info.blocks.keySet.iterator + while (iterator.hasNext) { + val blockId = iterator.next + val locations = blockLocations.get(blockId)._2 + locations -= blockManagerId + if (locations.size == 0) { + blockLocations.remove(locations) + } + } + } + + def expireDeadHosts() { + logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.") + val now = System.currentTimeMillis() + val minSeenTime = now - slaveTimeout + val toRemove = new HashSet[BlockManagerId] + for (info <- blockManagerInfo.values) { + if (info.lastSeenMs < minSeenTime) { + logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats") + toRemove += info.blockManagerId + } + } + toRemove.foreach(removeBlockManager) + } + + def removeHost(host: String) { + logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") + logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) + blockManagerIdByHost.get(host).foreach(_.foreach(removeBlockManager)) + logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + sender ! true + } + + def heartBeat(blockManagerId: BlockManagerId) { + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + sender ! true + } else { + sender ! false + } + } else { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + } + } + + // Remove a block from the slaves that have it. This can only be used to remove + // blocks that the master knows about. + private def removeBlock(blockId: String) { + val block = blockLocations.get(blockId) + if (block != null) { + block._2.foreach { blockManagerId: BlockManagerId => + val blockManager = blockManagerInfo.get(blockManagerId) + if (blockManager.isDefined) { + // Remove the block from the slave's BlockManager. + // Doesn't actually wait for a confirmation and the message might get lost. + // If message loss becomes frequent, we should add retry logic here. + blockManager.get.slaveActor ! RemoveBlock(blockId) + } + } + } + sender ! true + } + + // Return a map from the block manager id to max memory and remaining memory. + private def getMemoryStatus() { + val res = blockManagerInfo.map { case(blockManagerId, info) => + (blockManagerId, (info.maxMem, info.remainingMem)) + }.toMap + sender ! res + } + + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + logInfo("Got Register Msg from master node, don't register it") + } else { + blockManagerIdByHost.get(blockManagerId.ip) match { + case Some(managers) => + // A block manager of the same host name already exists. + logInfo("Got another registration for host " + blockManagerId) + managers += blockManagerId + case None => + blockManagerIdByHost += (blockManagerId.ip -> ArrayBuffer(blockManagerId)) + } + + blockManagerInfo += (blockManagerId -> new BlockManagerMasterActor.BlockManagerInfo( + blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)) + } + sender ! true + } + + private def updateBlockInfo( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long) { + + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + blockId + " " + + if (!blockManagerInfo.contains(blockManagerId)) { + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + // We intentionally do not register the master (except in local mode), + // so we should not indicate failure. + sender ! true + } else { + sender ! false + } + return + } + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + sender ! true + return + } + + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) + + var locations: HashSet[BlockManagerId] = null + if (blockLocations.containsKey(blockId)) { + locations = blockLocations.get(blockId)._2 + } else { + locations = new HashSet[BlockManagerId] + blockLocations.put(blockId, (storageLevel.replication, locations)) + } + + if (storageLevel.isValid) { + locations.add(blockManagerId) + } else { + locations.remove(blockManagerId) + } + + // Remove the block from master tracking if it has been removed on all slaves. + if (locations.size == 0) { + blockLocations.remove(blockId) + } + sender ! true + } + + private def getLocations(blockId: String) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockId + " " + if (blockLocations.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockLocations.get(blockId)._2) + sender ! res.toSeq + } else { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + sender ! res + } + } + + private def getLocationsMultipleBlockIds(blockIds: Array[String]) { + def getLocations(blockId: String): Seq[BlockManagerId] = { + val tmp = blockId + if (blockLocations.containsKey(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockLocations.get(blockId)._2) + return res.toSeq + } else { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + return res.toSeq + } + } + + var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] + for (blockId <- blockIds) { + res.append(getLocations(blockId)) + } + sender ! res.toSeq + } + + private def getPeers(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(peers) + res -= blockManagerId + val rand = new Random(System.currentTimeMillis()) + while (res.length > size) { + res.remove(rand.nextInt(res.length)) + } + sender ! res.toSeq + } + + private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + + val selfIndex = peers.indexOf(blockManagerId) + if (selfIndex == -1) { + throw new Exception("Self index for " + blockManagerId + " not found") + } + + // Note that this logic will select the same node multiple times if there aren't enough peers + var index = selfIndex + while (res.size < size) { + index += 1 + if (index == selfIndex) { + throw new Exception("More peer expected than available") + } + res += peers(index % peers.size) + } + sender ! res.toSeq + } +} + + +private[spark] +object BlockManagerMasterActor { + + case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) + + class BlockManagerInfo( + val blockManagerId: BlockManagerId, + timeMs: Long, + val maxMem: Long, + val slaveActor: ActorRef) + extends Logging { + + private var _lastSeenMs: Long = timeMs + private var _remainingMem: Long = maxMem + + // Mapping from block id to its status. + private val _blocks = new JHashMap[String, BlockStatus] + + logInfo("Registering block manager %s:%d with %s RAM".format( + blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) + + def updateLastSeenMs() { + _lastSeenMs = System.currentTimeMillis() + } + + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) + : Unit = synchronized { + + updateLastSeenMs() + + if (_blocks.containsKey(blockId)) { + // The block exists on the slave already. + val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + + if (originalLevel.useMemory) { + _remainingMem += memSize + } + } + + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. + _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize)) + if (storageLevel.useMemory) { + _remainingMem -= memSize + logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (storageLevel.useDisk) { + logInfo("Added %s on disk on %s:%d (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } else if (_blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val blockStatus: BlockStatus = _blocks.get(blockId) + _blocks.remove(blockId) + if (blockStatus.storageLevel.useMemory) { + _remainingMem += blockStatus.memSize + logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(_remainingMem))) + } + if (blockStatus.storageLevel.useDisk) { + logInfo("Removed %s on %s:%d on disk (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } + } + } + + def remainingMem: Long = _remainingMem + + def lastSeenMs: Long = _lastSeenMs + + def blocks: JHashMap[String, BlockStatus] = _blocks + + override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem + + def clear() { + _blocks.clear() + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala new file mode 100644 index 0000000000..30483b0b37 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala @@ -0,0 +1,100 @@ +package spark.storage + +import java.io.{Externalizable, ObjectInput, ObjectOutput} + +import akka.actor.ActorRef + + +////////////////////////////////////////////////////////////////////////////////// +// Messages from the master to slaves. +////////////////////////////////////////////////////////////////////////////////// +private[spark] +sealed trait ToBlockManagerSlave + +// Remove a block from the slaves that have it. This can only be used to remove +// blocks that the master knows about. +private[spark] +case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + + +////////////////////////////////////////////////////////////////////////////////// +// Messages from slaves to the master. +////////////////////////////////////////////////////////////////////////////////// +private[spark] +sealed trait ToBlockManagerMaster + +private[spark] +case class RegisterBlockManager( + blockManagerId: BlockManagerId, + maxMemSize: Long, + sender: ActorRef) + extends ToBlockManagerMaster + +private[spark] +case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + +private[spark] +class UpdateBlockInfo( + var blockManagerId: BlockManagerId, + var blockId: String, + var storageLevel: StorageLevel, + var memSize: Long, + var diskSize: Long) + extends ToBlockManagerMaster + with Externalizable { + + def this() = this(null, null, null, 0, 0) // For deserialization only + + override def writeExternal(out: ObjectOutput) { + blockManagerId.writeExternal(out) + out.writeUTF(blockId) + storageLevel.writeExternal(out) + out.writeInt(memSize.toInt) + out.writeInt(diskSize.toInt) + } + + override def readExternal(in: ObjectInput) { + blockManagerId = BlockManagerId(in) + blockId = in.readUTF() + storageLevel = StorageLevel(in) + memSize = in.readInt() + diskSize = in.readInt() + } +} + +private[spark] +object UpdateBlockInfo { + def apply(blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) + } + + // For pattern-matching + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) + } +} + +private[spark] +case class GetLocations(blockId: String) extends ToBlockManagerMaster + +private[spark] +case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + +private[spark] +case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster + +private[spark] +case class RemoveHost(host: String) extends ToBlockManagerMaster + +private[spark] +case object StopBlockManagerMaster extends ToBlockManagerMaster + +private[spark] +case object GetMemoryStatus extends ToBlockManagerMaster + +private[spark] +case object ExpireDeadHosts extends ToBlockManagerMaster diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala new file mode 100644 index 0000000000..f570cdc52d --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala @@ -0,0 +1,16 @@ +package spark.storage + +import akka.actor.Actor + +import spark.{Logging, SparkException, Utils} + + +/** + * An actor to take commands from the master to execute options. For example, + * this is used to remove blocks from the slave's BlockManager. + */ +class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { + override def receive = { + case RemoveBlock(blockId) => blockManager.removeBlock(blockId) + } +} diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 3f234df654..30d7500e01 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -64,7 +64,7 @@ private[spark] class BlockMessage() { val booleanInt = buffer.getInt() val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) + level = StorageLevel(booleanInt, replication) val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 096bf8bdd9..8188d3595e 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -31,7 +31,12 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { def getValues(blockId: String): Option[Iterator[Any]] - def remove(blockId: String) + /** + * Remove a block, if it exists. + * @param blockId the block to remove. + * @return True if the block was found and removed, False otherwise. + */ + def remove(blockId: String): Boolean def contains(blockId: String): Boolean diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index b5561479db..7e5b820cbb 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -92,10 +92,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) } - override def remove(blockId: String) { + override def remove(blockId: String): Boolean = { val file = getFile(blockId) if (file.exists()) { file.delete() + true + } else { + false } } diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index 02098b82fe..ae88ff0bb1 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -17,7 +17,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) private var currentMemory = 0L - // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. private val putLock = new Object() @@ -90,7 +89,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: String) { + override def remove(blockId: String): Boolean = { entries.synchronized { val entry = entries.get(blockId) if (entry != null) { @@ -98,8 +97,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) currentMemory -= entry.size logInfo("Block %s of size %d dropped from memory (free %d)".format( blockId, entry.size, freeMemory)) + true } else { - logWarning("Block " + blockId + " could not be removed as it does not exist") + false } } } diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index c497f03e0c..d1d1c61c1c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,53 +1,60 @@ package spark.storage -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for - * commonly useful storage levels. + * commonly useful storage levels. To create your own storage level object, use the factor method + * of the singleton object (`StorageLevel(...)`). */ -class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) +class StorageLevel private( + private var useDisk_ : Boolean, + private var useMemory_ : Boolean, + private var deserialized_ : Boolean, + private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - - def this(flags: Int, replication: Int) { + private def this(flags: Int, replication: Int) { this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } def this() = this(false, true, false) // For deserialization + def useDisk = useDisk_ + def useMemory = useMemory_ + def deserialized = deserialized_ + def replication = replication_ + + assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + override def clone(): StorageLevel = new StorageLevel( this.useDisk, this.useMemory, this.deserialized, this.replication) override def equals(other: Any): Boolean = other match { case s: StorageLevel => - s.useDisk == useDisk && + s.useDisk == useDisk && s.useMemory == useMemory && s.deserialized == deserialized && - s.replication == replication + s.replication == replication case _ => false } - + def isValid = ((useMemory || useDisk) && (replication > 0)) def toInt: Int = { var ret = 0 - if (useDisk) { + if (useDisk_) { ret |= 4 } - if (useMemory) { + if (useMemory_) { ret |= 2 } - if (deserialized) { + if (deserialized_) { ret |= 1 } return ret @@ -55,21 +62,27 @@ class StorageLevel( override def writeExternal(out: ObjectOutput) { out.writeByte(toInt) - out.writeByte(replication) + out.writeByte(replication_) } override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() + useDisk_ = (flags & 4) != 0 + useMemory_ = (flags & 2) != 0 + deserialized_ = (flags & 1) != 0 + replication_ = in.readByte() } + @throws(classOf[IOException]) + private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) + override def toString: String = "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + + override def hashCode(): Int = toInt * 41 + replication } + object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) @@ -82,4 +95,31 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + + /** Create a new StorageLevel object */ + def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) = + getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + + /** Create a new StorageLevel object from its integer representation */ + def apply(flags: Int, replication: Int) = + getCachedStorageLevel(new StorageLevel(flags, replication)) + + /** Read StorageLevel object from ObjectInput stream */ + def apply(in: ObjectInput) = { + val obj = new StorageLevel() + obj.readExternal(in) + getCachedStorageLevel(obj) + } + + private[spark] + val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]() + + private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = { + if (storageLevelCache.containsKey(level)) { + storageLevelCache.get(level) + } else { + storageLevelCache.put(level, level) + level + } + } } diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala index 5bb5a29cc4..689f07b969 100644 --- a/core/src/main/scala/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/spark/storage/ThreadingTest.scala @@ -58,8 +58,10 @@ private[spark] object ThreadingTest { val startTime = System.currentTimeMillis() manager.get(blockId) match { case Some(retrievedBlock) => - assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") + assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, + "Block " + blockId + " did not match") + println("Got block " + blockId + " in " + + (System.currentTimeMillis - startTime) + " ms") case None => assert(false, "Block " + blockId + " could not be retrieved") } @@ -73,7 +75,9 @@ private[spark] object ThreadingTest { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") val serializer = new KryoSerializer - val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true) + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort) val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) @@ -86,6 +90,7 @@ private[spark] object ThreadingTest { actorSystem.shutdown() actorSystem.awaitTermination() println("Everything stopped.") - println("It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") + println( + "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") } } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index e67cb0336d..fbd0ff46bf 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -32,6 +32,7 @@ private[spark] object AkkaUtils { akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] akka.actor.provider = "akka.remote.RemoteActorRefProvider" akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.log-remote-lifecycle-events = on akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = %ds diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala new file mode 100644 index 0000000000..b6e309fe1a --- /dev/null +++ b/core/src/main/scala/spark/util/IdGenerator.scala @@ -0,0 +1,14 @@ +package spark.util + +import java.util.concurrent.atomic.AtomicInteger + +/** + * A util used to get a unique generation ID. This is a wrapper around Java's + * AtomicInteger. An example usage is in BlockManager, where each BlockManager + * instance would start an Akka actor and we use this utility to assign the Akka + * actors unique names. + */ +private[spark] class IdGenerator { + private var id = new AtomicInteger + def next: Int = id.incrementAndGet +} diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala new file mode 100644 index 0000000000..139e21d09e --- /dev/null +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -0,0 +1,44 @@ +package spark.util + +import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} +import java.util.{TimerTask, Timer} +import spark.Logging + + +class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { + + val delaySeconds = MetadataCleaner.getDelaySeconds + val periodSeconds = math.max(10, delaySeconds / 10) + val timer = new Timer(name + " cleanup timer", true) + + val task = new TimerTask { + def run() { + try { + if (delaySeconds > 0) { + cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) + logInfo("Ran metadata cleaner for " + name) + } + } catch { + case e: Exception => logError("Error running cleanup task for " + name, e) + } + } + } + + if (periodSeconds > 0) { + logInfo( + "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and " + + "period of " + periodSeconds + " secs") + timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000) + } + + def cancel() { + timer.cancel() + } +} + + +object MetadataCleaner { + def getDelaySeconds = (System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt + def setDelaySeconds(delay: Long) { System.setProperty("spark.cleaner.delay", delay.toString) } +} + diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala new file mode 100644 index 0000000000..e3f00ea8c7 --- /dev/null +++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala @@ -0,0 +1,62 @@ +package spark.util + +import scala.annotation.tailrec + +import java.io.OutputStream +import java.util.concurrent.TimeUnit._ + +class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream { + val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS) + val CHUNK_SIZE = 8192 + var lastSyncTime = System.nanoTime + var bytesWrittenSinceSync: Long = 0 + + override def write(b: Int) { + waitToWrite(1) + out.write(b) + } + + override def write(bytes: Array[Byte]) { + write(bytes, 0, bytes.length) + } + + @tailrec + override final def write(bytes: Array[Byte], offset: Int, length: Int) { + val writeSize = math.min(length - offset, CHUNK_SIZE) + if (writeSize > 0) { + waitToWrite(writeSize) + out.write(bytes, offset, writeSize) + write(bytes, offset + writeSize, length) + } + } + + override def flush() { + out.flush() + } + + override def close() { + out.close() + } + + @tailrec + private def waitToWrite(numBytes: Int) { + val now = System.nanoTime + val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS) + val rate = bytesWrittenSinceSync.toDouble / elapsedSecs + if (rate < bytesPerSec) { + // It's okay to write; just update some variables and return + bytesWrittenSinceSync += numBytes + if (now > lastSyncTime + SYNC_INTERVAL) { + // Sync interval has passed; let's resync + lastSyncTime = now + bytesWrittenSinceSync = numBytes + } + } else { + // Calculate how much time we should sleep to bring ourselves to the desired rate. + // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala) + val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS) + if (sleepTime > 0) Thread.sleep(sleepTime) + waitToWrite(numBytes) + } + } +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala new file mode 100644 index 0000000000..bb7c5c01c8 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala @@ -0,0 +1,93 @@ +package spark.util + +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConversions +import scala.collection.mutable.Map + +/** + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * time stamp along with each key-value pair. Key-value pairs that are older than a particular + * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in + * replacement of scala.collection.mutable.HashMap. + */ +class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging { + val internalMap = new ConcurrentHashMap[A, (B, Long)]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null) Some(value._1) else None + } + + def iterator: Iterator[(A, B)] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + } + + override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.put(kv._1, (kv._2, currentTime)) + newMap + } + + override def - (key: A): Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap.put(kv._1, (kv._2, currentTime)) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + val value = internalMap.get(key) + if (value == null) throw new NoSuchElementException() + value._1 + } + + override def filter(p: ((A, B)) => Boolean): Map[A, B] = { + JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + } + + override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: ((A, B)) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue._1) + f(kv) + } + } + + /** + * Removes old key-value pairs that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue._2 < threshTime) { + logDebug("Removing key " + entry.getKey) + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() + +} diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala new file mode 100644 index 0000000000..5f1cc93752 --- /dev/null +++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala @@ -0,0 +1,69 @@ +package spark.util + +import scala.collection.mutable.Set +import scala.collection.JavaConversions +import java.util.concurrent.ConcurrentHashMap + + +class TimeStampedHashSet[A] extends Set[A] { + val internalMap = new ConcurrentHashMap[A, Long]() + + def contains(key: A): Boolean = { + internalMap.contains(key) + } + + def iterator: Iterator[A] = { + val jIterator = internalMap.entrySet().iterator() + JavaConversions.asScalaIterator(jIterator).map(_.getKey) + } + + override def + (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet += elem + newSet + } + + override def - (elem: A): Set[A] = { + val newSet = new TimeStampedHashSet[A] + newSet ++= this + newSet -= elem + newSet + } + + override def += (key: A): this.type = { + internalMap.put(key, currentTime) + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def empty: Set[A] = new TimeStampedHashSet[A]() + + override def size(): Int = internalMap.size() + + override def foreach[U](f: (A) => U): Unit = { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + f(iterator.next.getKey) + } + } + + /** + * Removes old values that have timestamp earlier than `threshTime` + */ + def clearOldValues(threshTime: Long) { + val iterator = internalMap.entrySet().iterator() + while(iterator.hasNext) { + val entry = iterator.next() + if (entry.getValue < threshTime) { + iterator.remove() + } + } + } + + private def currentTime: Long = System.currentTimeMillis() +} diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html index c32ab30401..be69e9bf02 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html @@ -7,6 +7,7 @@ <a href="@worker.webUiAddress">@worker.id</href> </td> <td>@{worker.host}:@{worker.port}</td> + <td>@worker.state</td> <td>@worker.cores (@worker.coresUsed Used)</td> <td>@{Utils.memoryMegabytesToString(worker.memory)} (@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td> diff --git a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html index fad1af41dc..b249411a62 100644 --- a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html +++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html @@ -5,6 +5,7 @@ <tr> <th>ID</th> <th>Address</th> + <th>State</th> <th>Cores</th> <th>Memory</th> </tr> |