diff options
author | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-10-19 09:44:32 -0700 |
---|---|---|
committer | Josh Rosen <joshrosen@eecs.berkeley.edu> | 2012-10-19 09:44:32 -0700 |
commit | e21eb6e00ddb77f40ecca9144b7405a293b97573 (patch) | |
tree | 5aa734b0047a738f41c27d272499b7080e9054e1 /core/src/main/scala | |
parent | 9abdfa663360252d2edb346e6b3df4ff94ce78d7 (diff) | |
parent | 63fe4e9d33ec59d93b42507ca9ea286178c12ec4 (diff) | |
download | spark-e21eb6e00ddb77f40ecca9144b7405a293b97573.tar.gz spark-e21eb6e00ddb77f40ecca9144b7405a293b97573.tar.bz2 spark-e21eb6e00ddb77f40ecca9144b7405a293b97573.zip |
Merge tag 'v0.6.0' into python-api
Diffstat (limited to 'core/src/main/scala')
160 files changed, 4917 insertions, 3679 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala index d764ffc29d..bacd0ace37 100644 --- a/core/src/main/scala/spark/Accumulators.scala +++ b/core/src/main/scala/spark/Accumulators.scala @@ -3,18 +3,20 @@ package spark import java.io._ import scala.collection.mutable.Map +import scala.collection.generic.Growable /** - * A datatype that can be accumulated, i.e. has an commutative and associative +. + * A datatype that can be accumulated, i.e. has an commutative and associative "add" operation, + * but where the result type, `R`, may be different from the element type being added, `T`. * - * You must define how to add data, and how to merge two of these together. For some datatypes, these might be - * the same operation (eg., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't - * always be the same, though -- eg., imagine you are accumulating a set. You will add items to the set, and you - * will union two sets together. + * You must define how to add data, and how to merge two of these together. For some datatypes, + * such as a counter, these might be the same operation. In that case, you can use the simpler + * [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are + * accumulating a set. You will add items to the set, and you will union two sets together. * * @param initialValue initial value of accumulator - * @param param helper object defining how to add elements of type `T` - * @tparam R the full accumulated data + * @param param helper object defining how to add elements of type `R` and `T` + * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ class Accumulable[R, T] ( @@ -43,13 +45,29 @@ class Accumulable[R, T] ( * @param term the other Accumulable that will get merged with this */ def ++= (term: R) { value_ = param.addInPlace(value_, term)} + + /** + * Access the accumulator's current value; only allowed on master. + */ def value = { if (!deserialized) value_ else throw new UnsupportedOperationException("Can't read accumulator value in task") } - private[spark] def localValue = value_ + /** + * Get the current value of this accumulator from within a task. + * + * This is NOT the global value of the accumulator. To get the global value after a + * completed operation on the dataset, call `value`. + * + * The typical use of this method is to directly mutate the local value, eg., to add + * an element to a Set. + */ + def localValue = value_ + /** + * Set the accumulator's value; only allowed on master. + */ def value_= (r: R) { if (!deserialized) value_ = r else throw new UnsupportedOperationException("Can't assign accumulator value in task") @@ -67,31 +85,64 @@ class Accumulable[R, T] ( } /** - * Helper object defining how to accumulate values of a particular type. + * Helper object defining how to accumulate values of a particular type. An implicit + * AccumulableParam needs to be available when you create Accumulables of a specific type. * - * @tparam R the full accumulated data + * @tparam R the full accumulated data (result type) * @tparam T partial data that can be added in */ trait AccumulableParam[R, T] extends Serializable { /** - * Add additional data to the accumulator value. + * Add additional data to the accumulator value. Is allowed to modify and return `r` + * for efficiency (to avoid allocating objects). + * * @param r the current value of the accumulator * @param t the data to be added to the accumulator * @return the new value of the accumulator */ - def addAccumulator(r: R, t: T) : R + def addAccumulator(r: R, t: T): R /** - * Merge two accumulated values together + * Merge two accumulated values together. Is allowed to modify and return the first value + * for efficiency (to avoid allocating objects). + * * @param r1 one set of accumulated data * @param r2 another set of accumulated data * @return both data sets merged together */ def addInPlace(r1: R, r2: R): R + /** + * Return the "zero" (identity) value for an accumulator type, given its initial value. For + * example, if R was a vector of N dimensions, this would return a vector of N zeroes. + */ def zero(initialValue: R): R } +private[spark] +class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T] + extends AccumulableParam[R,T] { + + def addAccumulator(growable: R, elem: T): R = { + growable += elem + growable + } + + def addInPlace(t1: R, t2: R): R = { + t1 ++= t2 + t1 + } + + def zero(initialValue: R): R = { + // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. + // Instead we'll serialize it to a buffer and load it back. + val ser = (new spark.JavaSerializer).newInstance() + val copy = ser.deserialize[R](ser.serialize(initialValue)) + copy.clear() // In case it contained stuff + copy + } +} + /** * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same * as the types of elements being merged. @@ -100,17 +151,18 @@ trait AccumulableParam[R, T] extends Serializable { * @param param helper object defining how to add elements of type `T` * @tparam T result type */ -class Accumulator[T]( - @transient initialValue: T, - param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param) +class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T]) + extends Accumulable[T,T](initialValue, param) /** * A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type - * as the accumulated value + * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create + * Accumulators of a specific type. + * * @tparam T type of value to accumulate */ trait AccumulatorParam[T] extends AccumulableParam[T, T] { - def addAccumulator(t1: T, t2: T) : T = { + def addAccumulator(t1: T, t2: T): T = { addInPlace(t1, t2) } } diff --git a/core/src/main/scala/spark/Aggregator.scala b/core/src/main/scala/spark/Aggregator.scala index 6f99270b1e..b0daa70cfd 100644 --- a/core/src/main/scala/spark/Aggregator.scala +++ b/core/src/main/scala/spark/Aggregator.scala @@ -1,7 +1,17 @@ package spark -class Aggregator[K, V, C] ( +/** A set of functions used to aggregate data. + * + * @param createCombiner function to create the initial value of the aggregation. + * @param mergeValue function to merge a new value into the aggregation result. + * @param mergeCombiners function to merge outputs from multiple mergeValue function. + * @param mapSideCombine whether to apply combiners on map partitions, also + * known as map-side aggregations. When set to false, + * mergeCombiners function is not used. + */ +case class Aggregator[K, V, C] ( val createCombiner: V => C, val mergeValue: (C, V) => C, - val mergeCombiners: (C, C) => C) - extends Serializable + val mergeCombiners: (C, C) => C, + val mapSideCombine: Boolean = true) + diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala index 3431ad2258..4554db2249 100644 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -11,56 +11,49 @@ import spark.storage.BlockManagerId import it.unimi.dsi.fastutil.io.FastBufferedInputStream - -class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { +private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId) + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]] - for ((address, index) <- addresses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] + for (((address, size), index) <- statuses.zipWithIndex) { + splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) } - val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map { + val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { case (address, splits) => - (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) + (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) } - try { - for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { - blockOption match { - case Some(block) => { - val values = block - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) + for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) { + blockOption match { + case Some(block) => { + val values = block + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) } } - } - } catch { - case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be + case None => { + val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r + blockId match { + case regex(shufId, mapId, _) => + val address = statuses(mapId.toInt)._1 + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block") } } } } + logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) } diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index 6fe0b94297..e8392a194f 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -9,7 +9,7 @@ import java.util.LinkedHashMap * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well * when most of the space is used by arrays of primitives or of simple classes. */ -class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { +private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { logInfo("BoundedMemoryCache.maxBytes = " + maxBytes) def this() { @@ -104,9 +104,9 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { } // An entry in our map; stores a cached object and its size in bytes -case class Entry(value: Any, size: Long) +private[spark] case class Entry(value: Any, size: Long) -object BoundedMemoryCache { +private[spark] object BoundedMemoryCache { /** * Get maximum cache capacity from system configuration */ diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala index 150fe14e2c..20d677a854 100644 --- a/core/src/main/scala/spark/Cache.scala +++ b/core/src/main/scala/spark/Cache.scala @@ -2,9 +2,9 @@ package spark import java.util.concurrent.atomic.AtomicInteger -sealed trait CachePutResponse -case class CachePutSuccess(size: Long) extends CachePutResponse -case class CachePutFailure() extends CachePutResponse +private[spark] sealed trait CachePutResponse +private[spark] case class CachePutSuccess(size: Long) extends CachePutResponse +private[spark] case class CachePutFailure() extends CachePutResponse /** * An interface for caches in Spark, to allow for multiple implementations. Caches are used to store @@ -22,7 +22,7 @@ case class CachePutFailure() extends CachePutResponse * This abstract class handles the creation of key spaces, so that subclasses need only deal with * keys that are unique across modules. */ -abstract class Cache { +private[spark] abstract class Cache { private val nextKeySpaceId = new AtomicInteger(0) private def newKeySpaceId() = nextKeySpaceId.getAndIncrement() @@ -52,7 +52,7 @@ abstract class Cache { /** * A key namespace in a Cache. */ -class KeySpace(cache: Cache, val keySpaceId: Int) { +private[spark] class KeySpace(cache: Cache, val keySpaceId: Int) { def get(datasetId: Any, partition: Int): Any = cache.get((keySpaceId, datasetId), partition) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 22110832f8..c5db6ce63a 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -15,19 +15,20 @@ import scala.collection.mutable.HashSet import spark.storage.BlockManager import spark.storage.StorageLevel -sealed trait CacheTrackerMessage -case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) +private[spark] sealed trait CacheTrackerMessage + +private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) extends CacheTrackerMessage -case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) +private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L) extends CacheTrackerMessage -case class MemoryCacheLost(host: String) extends CacheTrackerMessage -case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage -case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage -case object GetCacheStatus extends CacheTrackerMessage -case object GetCacheLocations extends CacheTrackerMessage -case object StopCacheTracker extends CacheTrackerMessage - -class CacheTrackerActor extends Actor with Logging { +private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage +private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage +private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage +private[spark] case object GetCacheStatus extends CacheTrackerMessage +private[spark] case object GetCacheLocations extends CacheTrackerMessage +private[spark] case object StopCacheTracker extends CacheTrackerMessage + +private[spark] class CacheTrackerActor extends Actor with Logging { // TODO: Should probably store (String, CacheType) tuples private val locs = new HashMap[Int, Array[List[String]]] @@ -43,8 +44,6 @@ class CacheTrackerActor extends Actor with Logging { def receive = { case SlaveCacheStarted(host: String, size: Long) => - logInfo("Started slave cache (size %s) on %s".format( - Utils.memoryBytesToString(size), host)) slaveCapacity.put(host, size) slaveUsage.put(host, 0) sender ! true @@ -56,22 +55,12 @@ class CacheTrackerActor extends Actor with Logging { case AddedToCache(rddId, partition, host, size) => slaveUsage.put(host, getCacheUsage(host) + size) - logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) locs(rddId)(partition) = host :: locs(rddId)(partition) sender ! true case DroppedFromCache(rddId, partition, host, size) => - logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) slaveUsage.put(host, getCacheUsage(host) - size) // Do a sanity check to make sure usage is greater than 0. - val usage = getCacheUsage(host) - if (usage < 0) { - logError("Cache usage on %s is negative (%d)".format(host, usage)) - } locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) sender ! true @@ -101,7 +90,7 @@ class CacheTrackerActor extends Actor with Logging { } } -class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) +private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) extends Logging { // Tracker actor on the master, or remote reference to it on workers @@ -151,7 +140,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl logInfo("Registering RDD ID " + rddId + " with cache") registeredRddIds += rddId communicate(RegisterRDD(rddId, numPartitions)) - logInfo(RegisterRDD(rddId, numPartitions) + " successful") } } } @@ -169,9 +157,8 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl } // For BlockManager.scala only - def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { + def notifyFromBlockManager(t: AddedToCache) { communicate(t) - logInfo("notifyTheCacheTrackerFromBlockManager successful") } // Get a snapshot of the currently known locations @@ -181,7 +168,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl // Gets or computes an RDD split def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { - val key = "rdd:%d:%d".format(rdd.id, split.index) + val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { case Some(cachedValues) => @@ -221,23 +208,19 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl // TODO: fetch any remote copy of the split that may be available // TODO: also register a listener for when it unloads logInfo("Computing partition " + split) + val elements = new ArrayBuffer[Any] + elements ++= rdd.compute(split) try { - // BlockManager will iterate over results from compute to create RDD - blockManager.put(key, rdd.compute(split), storageLevel, false) + // Try to put this block in the blockManager + blockManager.put(key, elements, storageLevel, true) //future.apply() // Wait for the reply from the cache tracker - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logWarning("loading partition failed after computing it " + key) - return null - } } finally { loading.synchronized { loading.remove(key) loading.notifyAll() } } + return elements.iterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/spark/ClosureCleaner.scala b/core/src/main/scala/spark/ClosureCleaner.scala index 3b83d23a13..98525b99c8 100644 --- a/core/src/main/scala/spark/ClosureCleaner.scala +++ b/core/src/main/scala/spark/ClosureCleaner.scala @@ -9,7 +9,7 @@ import org.objectweb.asm.{ClassReader, MethodVisitor, Type} import org.objectweb.asm.commons.EmptyVisitor import org.objectweb.asm.Opcodes._ -object ClosureCleaner extends Logging { +private[spark] object ClosureCleaner extends Logging { // Get an ASM class reader for a given class from the JAR that loaded it private def getClassReader(cls: Class[_]): ClassReader = { new ClassReader(cls.getResourceAsStream( @@ -154,7 +154,7 @@ object ClosureCleaner extends Logging { } } -class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { +private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor { override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { return new EmptyVisitor { @@ -180,7 +180,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor } } -class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { +private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor { var myName: String = null override def visit(version: Int, access: Int, name: String, sig: String, diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala index 003880c5e8..56e59adeb7 100644 --- a/core/src/main/scala/spark/DaemonThreadFactory.scala +++ b/core/src/main/scala/spark/DaemonThreadFactory.scala @@ -6,9 +6,13 @@ import java.util.concurrent.ThreadFactory * A ThreadFactory that creates daemon threads */ private object DaemonThreadFactory extends ThreadFactory { - override def newThread(r: Runnable): Thread = { - val t = new Thread(r) - t.setDaemon(true) - return t + override def newThread(r: Runnable): Thread = new DaemonThread(r) +} + +private class DaemonThread(r: Runnable = null) extends Thread { + override def run() { + if (r != null) { + r.run() + } } }
\ No newline at end of file diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index c0ff94acc6..dfc7e292b7 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -1,22 +1,53 @@ package spark -abstract class Dependency[T](val rdd: RDD[T], val isShuffle: Boolean) extends Serializable +/** + * Base class for dependencies. + */ +abstract class Dependency[T](val rdd: RDD[T]) extends Serializable -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) { +/** + * Base class for dependencies where each partition of the parent RDD is used by at most one + * partition of the child RDD. Narrow dependencies allow for pipelined execution. + */ +abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { + /** + * Get the parent partitions for a child partition. + * @param outputPartition a partition of the child RDD + * @return the partitions of the parent RDD that the child partition depends upon + */ def getParents(outputPartition: Int): Seq[Int] } +/** + * Represents a dependency on the output of a shuffle stage. + * @param shuffleId the shuffle id + * @param rdd the parent RDD + * @param aggregator optional aggregator; this allows for map-side combining + * @param partitioner partitioner used to partition the shuffle output + */ class ShuffleDependency[K, V, C]( - val shuffleId: Int, @transient rdd: RDD[(K, V)], - val aggregator: Aggregator[K, V, C], + val aggregator: Option[Aggregator[K, V, C]], val partitioner: Partitioner) - extends Dependency(rdd, true) + extends Dependency(rdd) { + val shuffleId: Int = rdd.context.newShuffleId() +} + +/** + * Represents a one-to-one dependency between partitions of the parent and child RDDs. + */ class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int) = List(partitionId) } +/** + * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. + * @param rdd the parent RDD + * @param inStart the start of the range in the parent RDD + * @param outStart the start of the range in the child RDD + * @param length the length of the range + */ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala index 1fbf66b7de..b2a0e2b631 100644 --- a/core/src/main/scala/spark/DoubleRDDFunctions.scala +++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala @@ -4,33 +4,49 @@ import spark.partial.BoundedDouble import spark.partial.MeanEvaluator import spark.partial.PartialResult import spark.partial.SumEvaluator - import spark.util.StatCounter /** * Extra functions available on RDDs of Doubles through an implicit conversion. + * Import `spark.SparkContext._` at the top of your program to use these functions. */ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { + /** Add up the elements in this RDD. */ def sum(): Double = { self.reduce(_ + _) } + /** + * Return a [[spark.util.StatCounter]] object that captures the mean, variance and count + * of the RDD's elements in one operation. + */ def stats(): StatCounter = { self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) } + /** Compute the mean of this RDD's elements. */ def mean(): Double = stats().mean + /** Compute the variance of this RDD's elements. */ def variance(): Double = stats().variance + /** Compute the standard deviation of this RDD's elements. */ def stdev(): Double = stats().stdev + /** + * Compute the sample standard deviation of this RDD's elements (which corrects for bias in + * estimating the standard deviation by dividing by N-1 instead of N). + */ + def sampleStdev(): Double = stats().stdev + + /** (Experimental) Approximate operation to return the mean within a timeout. */ def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new MeanEvaluator(self.splits.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new SumEvaluator(self.splits.size, confidence) diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index 55512f4481..a953081d24 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -2,7 +2,7 @@ package spark import spark.storage.BlockManagerId -class FetchFailedException( +private[spark] class FetchFailedException( val bmAddress: BlockManagerId, val shuffleId: Int, val mapId: Int, diff --git a/core/src/main/scala/spark/HadoopWriter.scala b/core/src/main/scala/spark/HadoopWriter.scala index 12b6a0954c..ffe0f3c4a1 100644 --- a/core/src/main/scala/spark/HadoopWriter.scala +++ b/core/src/main/scala/spark/HadoopWriter.scala @@ -16,9 +16,12 @@ import spark.Logging import spark.SerializableWritable /** - * Saves an RDD using a Hadoop OutputFormat as specified by a JobConf. The JobConf should also - * contain an output key class, an output value class, a filename to write to, etc exactly like in - * a Hadoop job. + * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public + * because we need to access this class from the `spark` package to use some package-private Hadoop + * functions, but this class should not be used directly by users. + * + * Saves the RDD using a JobConf, which should contain an output key class, an output value class, + * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializable { @@ -42,7 +45,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl setConfParams() val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) + getOutputCommitter().setupJob(jCtxt) } diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala new file mode 100644 index 0000000000..659d17718f --- /dev/null +++ b/core/src/main/scala/spark/HttpFileServer.scala @@ -0,0 +1,47 @@ +package spark + +import java.io.{File, PrintWriter} +import java.net.URL +import scala.collection.mutable.HashMap +import org.apache.hadoop.fs.FileUtil + +private[spark] class HttpFileServer extends Logging { + + var baseDir : File = null + var fileDir : File = null + var jarDir : File = null + var httpServer : HttpServer = null + var serverUri : String = null + + def initialize() { + baseDir = Utils.createTempDir() + fileDir = new File(baseDir, "files") + jarDir = new File(baseDir, "jars") + fileDir.mkdir() + jarDir.mkdir() + logInfo("HTTP File server directory is " + baseDir) + httpServer = new HttpServer(baseDir) + httpServer.start() + serverUri = httpServer.uri + } + + def stop() { + httpServer.stop() + } + + def addFile(file: File) : String = { + addFileToDir(file, fileDir) + return serverUri + "/files/" + file.getName + } + + def addJar(file: File) : String = { + addFileToDir(file, jarDir) + return serverUri + "/jars/" + file.getName + } + + def addFileToDir(file: File, dir: File) : String = { + Utils.copyFile(file, new File(dir, file.getName)) + return dir + "/" + file.getName + } + +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala index 855f2c752f..0196595ba1 100644 --- a/core/src/main/scala/spark/HttpServer.scala +++ b/core/src/main/scala/spark/HttpServer.scala @@ -12,14 +12,14 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ -class ServerStateException(message: String) extends Exception(message) +private[spark] class ServerStateException(message: String) extends Exception(message) /** * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. */ -class HttpServer(resourceBase: File) extends Logging { +private[spark] class HttpServer(resourceBase: File) extends Logging { private var server: Server = null private var port: Int = -1 diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index d11ba5167d..b04a27d073 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -3,16 +3,17 @@ package spark import java.io._ import java.nio.ByteBuffer +import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream} import spark.util.ByteBufferInputStream -class JavaSerializationStream(out: OutputStream) extends SerializationStream { +private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) - def writeObject[T](t: T) { objOut.writeObject(t) } + def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this } def flush() { objOut.flush() } def close() { objOut.close() } } -class JavaDeserializationStream(in: InputStream, loader: ClassLoader) +private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) extends DeserializationStream { val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass) = @@ -23,7 +24,7 @@ extends DeserializationStream { def close() { objIn.close() } } -class JavaSerializerInstance extends SerializerInstance { +private[spark] class JavaSerializerInstance extends SerializerInstance { def serialize[T](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -57,6 +58,9 @@ class JavaSerializerInstance extends SerializerInstance { } } +/** + * A Spark serializer that uses Java's built-in serialization. + */ class JavaSerializer extends Serializer { def newInstance(): SerializerInstance = new JavaSerializerInstance } diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 65d0532bd5..44b630e478 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -10,15 +10,18 @@ import scala.collection.mutable import com.esotericsoftware.kryo._ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer +import com.esotericsoftware.kryo.serialize.SerializableSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport +import serializer.{SerializerInstance, DeserializationStream, SerializationStream} +import spark.broadcast._ import spark.storage._ /** * Zig-zag encoder used to write object sizes to serialization streams. * Based on Kryo's integer encoder. */ -object ZigZag { +private[spark] object ZigZag { def writeInt(n: Int, out: OutputStream) { var value = n if ((value & ~0x7F) == 0) { @@ -66,22 +69,25 @@ object ZigZag { } } +private[spark] class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) extends SerializationStream { val channel = Channels.newChannel(out) - def writeObject[T](t: T) { + def writeObject[T](t: T): SerializationStream = { kryo.writeClassAndObject(threadBuffer, t) ZigZag.writeInt(threadBuffer.position(), out) threadBuffer.flip() channel.write(threadBuffer) threadBuffer.clear() + this } def flush() { out.flush() } def close() { out.close() } } +private[spark] class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) extends DeserializationStream { def readObject[T](): T = { @@ -92,7 +98,7 @@ extends DeserializationStream { def close() { in.close() } } -class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { +private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { val kryo = ks.kryo val threadBuffer = ks.threadBuffer.get() val objectBuffer = ks.objectBuffer.get() @@ -153,13 +159,21 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { } } -// Used by clients to register their own classes +/** + * Interface implemented by clients to register their classes with Kryo when using Kryo + * serialization. + */ trait KryoRegistrator { def registerClasses(kryo: Kryo): Unit } -class KryoSerializer extends Serializer with Logging { - val kryo = createKryo() +/** + * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]]. + */ +class KryoSerializer extends spark.serializer.Serializer with Logging { + // Make this lazy so that it only gets called once we receive our first task on each executor, + // so we can pull out any custom Kryo registrator from the user's JARs. + lazy val kryo = createKryo() val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 @@ -190,8 +204,8 @@ class KryoSerializer extends Serializer with Logging { (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1), None, ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY_DESER, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER), + StorageLevel.MEMORY_ONLY, + PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), GotBlock("1", ByteBuffer.allocate(1)), GetBlock("1") ) @@ -203,6 +217,10 @@ class KryoSerializer extends Serializer with Logging { kryo.register(classOf[Class[_]], new ClassSerializer(kryo)) kryo.setRegistrationOptional(true) + // Allow sending SerializableWritable + kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer()) + kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer()) + // Register some commonly used Scala singleton objects. Because these // are singletons, we must return the exact same local object when we // deserialize rather than returning a clone as FieldSerializer would. @@ -250,7 +268,8 @@ class KryoSerializer extends Serializer with Logging { val regCls = System.getProperty("spark.kryo.registrator") if (regCls != null) { logInfo("Running user registrator: " + regCls) - val reg = Class.forName(regCls).newInstance().asInstanceOf[KryoRegistrator] + val classLoader = Thread.currentThread.getContextClassLoader + val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator] reg.registerClasses(kryo) } kryo diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 69935b86de..90bae26202 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -15,7 +15,7 @@ trait Logging { private var log_ : Logger = null // Method to get or create the logger for this object - def log: Logger = { + protected def log: Logger = { if (log_ == null) { var className = this.getClass.getName // Ignore trailing $'s in the class names for Scala objects @@ -28,48 +28,48 @@ trait Logging { } // Log methods that take only a String - def logInfo(msg: => String) { + protected def logInfo(msg: => String) { if (log.isInfoEnabled) log.info(msg) } - def logDebug(msg: => String) { + protected def logDebug(msg: => String) { if (log.isDebugEnabled) log.debug(msg) } - def logTrace(msg: => String) { + protected def logTrace(msg: => String) { if (log.isTraceEnabled) log.trace(msg) } - def logWarning(msg: => String) { + protected def logWarning(msg: => String) { if (log.isWarnEnabled) log.warn(msg) } - def logError(msg: => String) { + protected def logError(msg: => String) { if (log.isErrorEnabled) log.error(msg) } // Log methods that take Throwables (Exceptions/Errors) too - def logInfo(msg: => String, throwable: Throwable) { + protected def logInfo(msg: => String, throwable: Throwable) { if (log.isInfoEnabled) log.info(msg, throwable) } - def logDebug(msg: => String, throwable: Throwable) { + protected def logDebug(msg: => String, throwable: Throwable) { if (log.isDebugEnabled) log.debug(msg, throwable) } - def logTrace(msg: => String, throwable: Throwable) { + protected def logTrace(msg: => String, throwable: Throwable) { if (log.isTraceEnabled) log.trace(msg, throwable) } - def logWarning(msg: => String, throwable: Throwable) { + protected def logWarning(msg: => String, throwable: Throwable) { if (log.isWarnEnabled) log.warn(msg, throwable) } - def logError(msg: => String, throwable: Throwable) { + protected def logError(msg: => String, throwable: Throwable) { if (log.isErrorEnabled) log.error(msg, throwable) } // Method for ensuring that logging is initialized, to avoid having multiple // threads do it concurrently (as SLF4J initialization is not thread safe). - def initLogging() { log } + protected def initLogging() { log } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 0c97cd44a1..45441aa5e5 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -1,5 +1,6 @@ package spark +import java.io._ import java.util.concurrent.ConcurrentHashMap import akka.actor._ @@ -10,20 +11,23 @@ import akka.util.Duration import akka.util.Timeout import akka.util.duration._ +import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import scheduler.MapStatus import spark.storage.BlockManagerId +import java.util.zip.{GZIPInputStream, GZIPOutputStream} -sealed trait MapOutputTrackerMessage -case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage -case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] sealed trait MapOutputTrackerMessage +private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String) + extends MapOutputTrackerMessage +private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) -extends Actor with Logging { +private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { def receive = { - case GetMapOutputLocations(shuffleId: Int) => - logInfo("Asked to get map output locations for shuffle " + shuffleId) - sender ! bmAddresses.get(shuffleId) + case GetMapOutputStatuses(shuffleId: Int, requester: String) => + logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) + sender ! tracker.getSerializedLocations(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -32,22 +36,26 @@ extends Actor with Logging { } } -class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { +private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { val ip: String = System.getProperty("spark.master.host", "localhost") val port: Int = System.getProperty("spark.master.port", "7077").toInt val actorName: String = "MapOutputTracker" val timeout = 10.seconds - private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. private var generation: Long = 0 - private var generationLock = new java.lang.Object + private val generationLock = new java.lang.Object + + // Cache a serialized version of the output statuses for each shuffle to send them out faster + var cacheGeneration = generation + val cachedSerializedStatuses = new HashMap[Int, Array[Byte]] var trackerActor: ActorRef = if (isMaster) { - val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) + val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName) logInfo("Registered MapOutputTrackerActor actor") actor } else { @@ -75,31 +83,34 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (bmAddresses.get(shuffleId) != null) { + if (mapStatuses.get(shuffleId) != null) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps)) + mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)) } - def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = bmAddresses.get(shuffleId) + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { + var array = mapStatuses.get(shuffleId) array.synchronized { - array(mapId) = bmAddress + array(mapId) = status } } - def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) { - bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs) + def registerMapOutputs( + shuffleId: Int, + statuses: Array[MapStatus], + changeGeneration: Boolean = false) { + mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) if (changeGeneration) { incrementGeneration() } } def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = bmAddresses.get(shuffleId) + var array = mapStatuses.get(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) == bmAddress) { + if (array(mapId).address == bmAddress) { array(mapId) = null } } @@ -112,11 +123,11 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg // Remembers which map output locations are currently being fetched on a worker val fetching = new HashSet[Int] - // Called on possibly remote nodes to get the server URIs for a given shuffle - def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { - val locs = bmAddresses.get(shuffleId) - if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + val statuses = mapStatuses.get(shuffleId) + if (statuses == null) { + logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -124,33 +135,38 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg try { fetching.wait() } catch { - case _ => + case e: InterruptedException => } } - return bmAddresses.get(shuffleId) + return mapStatuses.get(shuffleId).map(status => + (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) } else { fetching += shuffleId } } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] + val host = System.getProperty("spark.hostname", Utils.localHostName) + val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]] + val fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") - bmAddresses.put(shuffleId, fetched) + mapStatuses.put(shuffleId, fetchedStatuses) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } - return fetched + return fetchedStatuses.map(s => + (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) } else { - return locs + return statuses.map(s => + (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) } } def stop() { communicate(StopMapOutputTracker) - bmAddresses.clear() + mapStatuses.clear() trackerActor = null } @@ -158,6 +174,7 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg def incrementGeneration() { generationLock.synchronized { generation += 1 + logDebug("Increasing generation to " + generation) } } @@ -175,9 +192,83 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]] generation = newGen } } } + + def getSerializedLocations(shuffleId: Int): Array[Byte] = { + var statuses: Array[MapStatus] = null + var generationGotten: Long = -1 + generationLock.synchronized { + if (generation > cacheGeneration) { + cachedSerializedStatuses.clear() + cacheGeneration = generation + } + cachedSerializedStatuses.get(shuffleId) match { + case Some(bytes) => + return bytes + case None => + statuses = mapStatuses.get(shuffleId) + generationGotten = generation + } + } + // If we got here, we failed to find the serialized locations in the cache, so we pulled + // out a snapshot of the locations as "locs"; let's serialize and return that + val bytes = serializeStatuses(statuses) + logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) + // Add them into the table only if the generation hasn't changed while we were working + generationLock.synchronized { + if (generation == generationGotten) { + cachedSerializedStatuses(shuffleId) = bytes + } + } + return bytes + } + + // Serialize an array of map output locations into an efficient byte format so that we can send + // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will + // generally be pretty compressible because many map outputs will be on the same hostname. + def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + val out = new ByteArrayOutputStream + val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + objOut.writeObject(statuses) + objOut.close() + out.toByteArray + } + + // Opposite of serializeStatuses. + def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { + val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) + objIn.readObject().asInstanceOf[Array[MapStatus]] + } +} + +private[spark] object MapOutputTracker { + private val LOG_BASE = 1.1 + + /** + * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. + * We do this by encoding the log base 1.1 of the size as an integer, which can support + * sizes up to 35 GB with at most 10% error. + */ + def compressSize(size: Long): Byte = { + if (size <= 1L) { + 0 + } else { + math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte + } + } + + /** + * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. + */ + def decompressSize(compressedSize: Byte): Long = { + if (compressedSize == 0) { + 1 + } else { + math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong + } + } } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 64018f8c6b..0240fd95c7 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -1,11 +1,10 @@ package spark import java.io.EOFException -import java.net.URL import java.io.ObjectInputStream +import java.net.URL +import java.util.{Date, HashMap => JHashMap} import java.util.concurrent.atomic.AtomicLong -import java.util.{HashMap => JHashMap} -import java.util.Date import java.text.SimpleDateFormat import scala.collection.Map @@ -35,26 +34,53 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.TaskAttemptID import org.apache.hadoop.mapreduce.TaskAttemptContext -import spark.SparkContext._ import spark.partial.BoundedDouble import spark.partial.PartialResult +import spark.rdd._ +import spark.SparkContext._ /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. + * Import `spark.SparkContext._` at the top of your program to use these functions. */ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self: RDD[(K, V)]) extends Logging with Serializable { + /** + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C + * Note that V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD, and whether to perform + * map-side aggregation (if a mapper can produce multiple items with the same key). + */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - partitioner: Partitioner): RDD[(K, C)] = { - val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - new ShuffledRDD(self, aggregator, partitioner) + partitioner: Partitioner, + mapSideCombine: Boolean = true): RDD[(K, C)] = { + val aggregator = + if (mapSideCombine) { + new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) + } else { + // Don't apply map-side combiner. + // A sanity check to make sure mergeCombiners is not defined. + assert(mergeCombiners == null) + new Aggregator[K, V, C](createCombiner, mergeValue, null, false) + } + new ShuffledAggregatedRDD(self, aggregator, partitioner) } + /** + * Simplified version of combineByKey that hash-partitions the output RDD. + */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, @@ -62,10 +88,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits)) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { combineByKey[V]((v: V) => v, func, func, partitioner) } - + + /** + * Merge the values for each key using an associative reduce function, but return the results + * immediately to the master as a Map. This will also perform the merging locally on each mapper + * before sending results to a reducer, similarly to a "combiner" in MapReduce. + */ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { val map = new JHashMap[K, V] @@ -87,22 +123,34 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self.mapPartitions(reducePartition).reduce(mergeMaps) } - // Alias for backwards compatibility + /** Alias for reduceByKeyLocally */ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - // TODO: This should probably be a distributed version + /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - // TODO: This should probably be a distributed version + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[Map[K, BoundedDouble]] = { self.map(_._1).countByValueApprox(timeout, confidence) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits. + */ def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = { reduceByKey(new HashPartitioner(numSplits), func) } + /** + * Group the values for each key in the RDD into a single sequence. Allows controlling the + * partitioning of the resulting key-value pair RDD by passing a Partitioner. + */ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v @@ -112,19 +160,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( bufs.asInstanceOf[RDD[(K, Seq[V])]] } + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with into `numSplits` partitions. + */ def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = { groupByKey(new HashPartitioner(numSplits)) } - def partitionBy(partitioner: Partitioner): RDD[(K, V)] = { - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 - val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, mergeCombiners _, partitioner) - bufs.flatMapValues(buf => buf) + /** + * Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine` + * is true, Spark will group values of the same key together on the map side before the + * repartitioning, to only send each key over the network once. If a large number of + * duplicated keys are expected, and the size of the keys are large, `mapSideCombine` should + * be set to true. + */ + def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = { + if (mapSideCombine) { + def createCombiner(v: V) = ArrayBuffer(v) + def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2 + val bufs = combineByKey[ArrayBuffer[V]]( + createCombiner _, mergeValue _, mergeCombiners _, partitioner) + bufs.flatMapValues(buf => buf) + } else { + new RepartitionShuffledRDD(self, partitioner) + } } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => @@ -132,6 +200,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to + * partition the output RDD. + */ def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => @@ -143,6 +217,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to + * partition the output RDD. + */ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner) : RDD[(K, (Option[V], W))] = { this.cogroup(other, partitioner).flatMapValues { @@ -155,56 +235,117 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } - def combineByKey[C](createCombiner: V => C, - mergeValue: (C, V) => C, - mergeCombiners: (C, C) => C) : RDD[(K, C)] = { + /** + * Simplified version of combineByKey that hash-partitions the resulting RDD using the default + * parallelism level. + */ + def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) + : RDD[(K, C)] = { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level. + */ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { reduceByKey(defaultPartitioner(self), func) } + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with the default parallelism level. + */ def groupByKey(): RDD[(K, Seq[V])] = { groupByKey(defaultPartitioner(self)) } + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { join(other, defaultPartitioner(self, other)) } + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = { join(other, new HashPartitioner(numSplits)) } + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * using the default level of parallelism. + */ def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { leftOuterJoin(other, defaultPartitioner(self, other)) } + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * into `numSplits` partitions. + */ def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = { leftOuterJoin(other, new HashPartitioner(numSplits)) } + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD using the default parallelism level. + */ def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { rightOuterJoin(other, defaultPartitioner(self, other)) } + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD into the given number of partitions. + */ def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = { rightOuterJoin(other, new HashPartitioner(numSplits)) } + /** + * Return the key-value pairs in this RDD to the master as a Map. + */ def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*) - + + /** + * Pass each value in the key-value pair RDD through a map function without changing the keys; + * this also retains the original RDD's partitioning. + */ def mapValues[U](f: V => U): RDD[(K, U)] = { val cleanF = self.context.clean(f) new MappedValuesRDD(self, cleanF) } - + + /** + * Pass each value in the key-value pair RDD through a flatMap function without changing the + * keys; this also retains the original RDD's partitioning. + */ def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = { val cleanF = self.context.clean(f) new FlatMappedValuesRDD(self, cleanF) } - + + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), @@ -215,12 +356,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) } } - + + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { val cg = new CoGroupedRDD[K]( Seq(self.asInstanceOf[RDD[(_, _)]], - other1.asInstanceOf[RDD[(_, _)]], + other1.asInstanceOf[RDD[(_, _)]], other2.asInstanceOf[RDD[(_, _)]]), partitioner) val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) @@ -230,28 +375,46 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner(self, other)) } + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = { cogroup(other, new HashPartitioner(numSplits)) } + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { cogroup(other1, other2, new HashPartitioner(numSplits)) } + /** Alias for cogroup. */ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { cogroup(other, defaultPartitioner(self, other)) } + /** Alias for cogroup. */ def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) @@ -268,6 +431,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( return new HashPartitioner(self.context.defaultParallelism) } + /** + * Return the list of values in the RDD for key `key`. This operation is done efficiently if the + * RDD has a known partitioner by only searching the partition that the key maps to. + */ def lookup(key: K): Seq[V] = { self.partitioner match { case Some(p) => @@ -286,14 +453,26 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. + */ def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } - + + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) } + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ def saveAsNewAPIHadoopFile( path: String, keyClass: Class[_], @@ -302,6 +481,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration) } + /** + * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` + * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. + */ def saveAsNewAPIHadoopFile( path: String, keyClass: Class[_], @@ -349,6 +532,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( jobCommitter.cleanupJob(jobTaskContext) } + /** + * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class + * supporting the key and value types K and V in this RDD. + */ def saveAsHadoopFile( path: String, keyClass: Class[_], @@ -363,7 +550,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf)) saveAsHadoopDataset(conf) } - + + /** + * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for + * that storage system. The JobConf should set an OutputFormat and any output paths required + * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop + * MapReduce job. + */ def saveAsHadoopDataset(conf: JobConf) { val outputFormatClass = conf.getOutputFormat val keyClass = conf.getOutputKeyClass @@ -377,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( if (valueClass == null) { throw new SparkException("Output value class not set") } - + logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")") val writer = new HadoopWriter(conf) @@ -390,14 +583,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.setup(context.stageId, context.splitId, attemptNumber) writer.open() - + var count = 0 while(iter.hasNext) { val record = iter.next count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) } - + writer.close() writer.commit() } @@ -406,35 +599,33 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( writer.cleanup() } - def getKeyClass() = implicitly[ClassManifest[K]].erasure + private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure - def getValueClass() = implicitly[ClassManifest[V]].erasure + private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure } +/** + * Extra functions available on RDDs of (key, value) pairs where the key is sortable through + * an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these + * functions. They will work with any key type that has a `scala.math.Ordered` implementation. + */ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( self: RDD[(K, V)]) - extends Logging + extends Logging with Serializable { - def sortByKey(ascending: Boolean = true): RDD[(K,V)] = { - val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending)) - new SortedRDD(rangePartitionedRDD, ascending) + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = { + new ShuffledSortedRDD(self, ascending, numSplits) } } -class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean) - extends RDD[(K, V)](prev.context) { - - override def splits = prev.splits - override val partitioner = prev.partitioner - override val dependencies = List(new OneToOneDependency(prev)) - - override def compute(split: Split) = { - prev.iterator(split).toArray - .sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator - } -} - +private[spark] class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) { override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) @@ -442,9 +633,10 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)] override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))} } +private[spark] class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.context) { - + override def splits = prev.splits override val dependencies = List(new OneToOneDependency(prev)) override val partitioner = prev.partitioner @@ -454,6 +646,6 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U] } } -object Manifests { +private[spark] object Manifests { val seqSeqManifest = classManifest[Seq[Seq[_]]] } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index d79007ab40..9b57ae3b4f 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -3,7 +3,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer -class ParallelCollectionSplit[T: ClassManifest]( +private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, val slice: Int, values: Seq[T]) @@ -21,7 +21,7 @@ class ParallelCollectionSplit[T: ClassManifest]( override val index: Int = slice } -class ParallelCollection[T: ClassManifest]( +private[spark] class ParallelCollection[T: ClassManifest]( sc: SparkContext, @transient data: Seq[T], numSlices: Int) diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index 643541429f..b71021a082 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -1,10 +1,17 @@ package spark +/** + * An object that defines how the elements in a key-value pair RDD are partitioned by key. + * Maps each key to a partition ID, from 0 to `numPartitions - 1`. + */ abstract class Partitioner extends Serializable { def numPartitions: Int def getPartition(key: Any): Int } +/** + * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. + */ class HashPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions @@ -29,6 +36,10 @@ class HashPartitioner(partitions: Int) extends Partitioner { } } +/** + * A [[spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. + * Determines the ranges by sampling the RDD passed in. + */ class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( partitions: Int, @transient rdd: RDD[(K,V)], @@ -41,9 +52,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( Array() } else { val rddSize = rdd.count() - val maxSampleSize = partitions * 10.0 + val maxSampleSize = partitions * 20.0 val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0) - val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _) + val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _) if (rddSample.length == 0) { Array() } else { diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3fe8e8a4bf..ddb420efff 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -31,51 +31,86 @@ import spark.partial.BoundedDouble import spark.partial.CountEvaluator import spark.partial.GroupedCountEvaluator import spark.partial.PartialResult +import spark.rdd.BlockRDD +import spark.rdd.CartesianRDD +import spark.rdd.FilteredRDD +import spark.rdd.FlatMappedRDD +import spark.rdd.GlommedRDD +import spark.rdd.MappedRDD +import spark.rdd.MapPartitionsRDD +import spark.rdd.MapPartitionsWithSplitRDD +import spark.rdd.PipedRDD +import spark.rdd.SampledRDD +import spark.rdd.UnionRDD import spark.storage.StorageLevel import SparkContext._ /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, - * partitioned collection of elements that can be operated on in parallel. + * partitioned collection of elements that can be operated on in parallel. This class contains the + * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition, + * [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such + * as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on + * RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs + * that can be saved as SequenceFiles. These operations are automatically available on any RDD of + * the right type (e.g. RDD[(Int, Int)] through implicit conversions when you + * `import spark.SparkContext._`. * - * Each RDD is characterized by five main properties: - * - A list of splits (partitions) - * - A function for computing each split - * - A list of dependencies on other RDDs - * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) - * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for - * HDFS) + * Internally, each RDD is characterized by five main properties: * - * All the scheduling and execution in Spark is done based on these methods, allowing each RDD to - * implement its own way of computing itself. + * - A list of splits (partitions) + * - A function for computing each split + * - A list of dependencies on other RDDs + * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) + * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for + * an HDFS file) * - * This class also contains transformation methods available on all RDDs (e.g. map and filter). In - * addition, PairRDDFunctions contains extra methods available on RDDs of key-value pairs, and - * SequenceFileRDDFunctions contains extra methods for saving RDDs to Hadoop SequenceFiles. + * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD + * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for + * reading data from a new storage system) by overriding these functions. Please refer to the + * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details + * on RDD internals. */ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable { - // Methods that must be implemented by subclasses + // Methods that must be implemented by subclasses: + + /** Set of partitions in this RDD. */ def splits: Array[Split] + + /** Function for computing a given partition. */ def compute(split: Split): Iterator[T] + + /** How this RDD depends on any parent RDDs. */ @transient val dependencies: List[Dependency[_]] + + // Methods available on all RDDs: + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite - // Optionally overridden by subclasses to specify how they are partitioned + /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None - // Optionally overridden by subclasses to specify placement preferences + /** Optionally overridden by subclasses to specify placement preferences. */ def preferredLocations(split: Split): Seq[String] = Nil + /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - // Get a unique ID for this RDD + /** A unique ID for this RDD (within its SparkContext). */ val id = sc.newRddId() // Variables relating to persistence private var storageLevel: StorageLevel = StorageLevel.NONE - // Change this RDD's storage level + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ def persist(newLevel: StorageLevel): RDD[T] = { // TODO: Handle changes of StorageLevel if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { @@ -86,22 +121,23 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial this } - // Turn on the default caching level for this RDD - def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER) + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY) - // Turn on the default caching level for this RDD + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): RDD[T] = persist() + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = { + private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = { if (!level.useDisk && level.replication < 2) { throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")") } // This is a hack. Ideally this should re-use the code used by the CacheTracker // to generate the key. - def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index) + def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index) persist(level) sc.runJob(this, (iter: Iterator[T]) => {} ) @@ -113,7 +149,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } } - // Read this RDD; will read from cache if applicable, or otherwise compute + /** + * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. + * This should ''not'' be called by users directly, but is available for implementors of custom + * subclasses of RDD. + */ final def iterator(split: Split): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) @@ -124,15 +164,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Transformations (return a new RDD) + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) - + + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = new FlatMappedRDD(this, sc.clean(f)) - + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f)) - def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numSplits: Int = splits.size): RDD[T] = + map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1) + /** + * Return a sampled subset of this RDD. + */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = new SampledRDD(this, withReplacement, fraction, seed) @@ -143,8 +200,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial var initialCount = count() var maxSelected = 0 - if (initialCount > Integer.MAX_VALUE) { - maxSelected = Integer.MAX_VALUE + if (initialCount > Integer.MAX_VALUE - 1) { + maxSelected = Integer.MAX_VALUE - 1 } else { maxSelected = initialCount.toInt } @@ -159,56 +216,108 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial total = num } - var samples = this.sample(withReplacement, fraction, seed).collect() + val rand = new Random(seed) + var samples = this.sample(withReplacement, fraction, rand.nextInt).collect() while (samples.length < total) { - samples = this.sample(withReplacement, fraction, seed).collect() + samples = this.sample(withReplacement, fraction, rand.nextInt).collect() } - val arr = samples.take(total) - - return arr + Utils.randomizeInPlace(samples, rand).take(total) } + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other)) + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ def ++(other: RDD[T]): RDD[T] = this.union(other) + /** + * Return an RDD created by coalescing all elements within each partition into an array. + */ def glom(): RDD[Array[T]] = new GlommedRDD(this) + /** + * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of + * elements (a, b) where a is in `this` and b is in `other`. + */ def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(numSplits) } + /** + * Return an RDD of grouped items. + */ def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism) + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: String): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command) + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: Seq[String], env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] = new MapPartitionsRDD(this, sc.clean(f)) + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] = + new MapPartitionsWithSplitRDD(this, sc.clean(f)) + // Actions (launch a job to return a value to the user program) - + + /** + * Applies a function f to all elements of this RDD. + */ def foreach(f: T => Unit) { val cleanF = sc.clean(f) sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) } + /** + * Return an array that contains all of the elements in this RDD. + */ def collect(): Array[T] = { val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray) Array.concat(results: _*) } + /** + * Return an array that contains all of the elements in this RDD. + */ def toArray(): Array[T] = collect() + /** + * Reduces the elements of this RDD using the specified associative binary operator. + */ def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { @@ -257,7 +366,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)) return results.fold(zeroValue)(cleanCombOp) } - + + /** + * Return the number of elements in the RDD. + */ def count(): Long = { sc.runJob(this, (iter: Iterator[T]) => { var result = 0L @@ -270,7 +382,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Approximate version of count() that returns a potentially incomplete result after a timeout. + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. */ def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => @@ -286,12 +399,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Count elements equal to each value, returning a map of (value, count) pairs. The final combine - * step happens locally on the master, equivalent to running a single reduce task. - * - * TODO: This should perhaps be distributed by default. + * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final + * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { + // TODO: This should perhaps be distributed by default. def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { val map = new OLMap[T] while (iter.hasNext) { @@ -313,7 +425,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial } /** - * Approximate version of countByValue(). + * (Experimental) Approximate version of countByValue(). */ def countByValueApprox( timeout: Long, @@ -353,18 +465,27 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial return buf.toArray } + /** + * Return the first element in this RDD. + */ def first(): T = take(1) match { case Array(t) => t case _ => throw new UnsupportedOperationException("empty collection") } + /** + * Save this RDD as a text file, using string representations of elements. + */ def saveAsTextFile(path: String) { this.map(x => (NullWritable.get(), new Text(x.toString))) .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path) } + /** + * Save this RDD as a SequenceFile of serialized objects. + */ def saveAsObjectFile(path: String) { - this.glom + this.mapPartitions(iter => iter.grouped(10).map(_.toArray)) .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x)))) .saveAsSequenceFile(path) } @@ -374,45 +495,3 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial sc.runJob(this, (iter: Iterator[T]) => iter.toArray) } } - -class MappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => U) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).map(f) -} - -class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: T => TraversableOnce[U]) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).flatMap(f) -} - -class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = prev.iterator(split).filter(f) -} - -class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator -} - -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: Iterator[T] => Iterator[U]) - extends RDD[U](prev.context) { - - override def splits = prev.splits - override val dependencies = List(new OneToOneDependency(prev)) - override def compute(split: Split) = f(prev.iterator(split)) -} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index ea7171d3a1..a34aee69c1 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -23,19 +23,21 @@ import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.Text -import SparkContext._ +import spark.SparkContext._ /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, * through an implicit conversion. Note that this can't be part of PairRDDFunctions because * we need more implicit parameters to convert our keys and values to Writable. + * + * Users should import `spark.SparkContext._` at the top of their program to use these functions. */ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( self: RDD[(K, V)]) extends Logging with Serializable { - def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { + private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { val c = { if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { classManifest[T].erasure @@ -47,6 +49,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla c.asInstanceOf[Class[_ <: Writable]] } + /** + * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key + * and value types. If the key or value are Writable, then we use their classes directly; + * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc, + * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported + * file system. + */ def saveAsSequenceFile(path: String) { def anyToWritable[U <% Writable](u: U): Writable = u diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala index 4f8d98f7d0..daa35fe7f2 100644 --- a/core/src/main/scala/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/spark/ShuffleFetcher.scala @@ -1,6 +1,6 @@ package spark -abstract class ShuffleFetcher { +private[spark] abstract class ShuffleFetcher { // Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly // once on each key-value pair obtained. def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) diff --git a/core/src/main/scala/spark/ShuffleManager.scala b/core/src/main/scala/spark/ShuffleManager.scala deleted file mode 100644 index 24af7f3a08..0000000000 --- a/core/src/main/scala/spark/ShuffleManager.scala +++ /dev/null @@ -1,98 +0,0 @@ -package spark - -import java.io._ -import java.net.URL -import java.util.UUID -import java.util.concurrent.atomic.AtomicLong - -import scala.collection.mutable.{ArrayBuffer, HashMap} - -import spark._ - -class ShuffleManager extends Logging { - private var nextShuffleId = new AtomicLong(0) - - private var shuffleDir: File = null - private var server: HttpServer = null - private var serverUri: String = null - - initialize() - - private def initialize() { - // TODO: localDir should be created by some mechanism common to Spark - // so that it can be shared among shuffle, broadcast, etc - val localDirRoot = System.getProperty("spark.local.dir", "/tmp") - var tries = 0 - var foundLocalDir = false - var localDir: File = null - var localDirUuid: UUID = null - while (!foundLocalDir && tries < 10) { - tries += 1 - try { - localDirUuid = UUID.randomUUID - localDir = new File(localDirRoot, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed 10 attempts to create local dir in " + localDirRoot) - System.exit(1) - } - shuffleDir = new File(localDir, "shuffle") - shuffleDir.mkdirs() - logInfo("Shuffle dir: " + shuffleDir) - - // Add a shutdown hook to delete the local dir - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dir") { - override def run() { - Utils.deleteRecursively(localDir) - } - }) - - val extServerPort = System.getProperty( - "spark.localFileShuffle.external.server.port", "-1").toInt - if (extServerPort != -1) { - // We're using an external HTTP server; set URI relative to its root - var extServerPath = System.getProperty( - "spark.localFileShuffle.external.server.path", "") - if (extServerPath != "" && !extServerPath.endsWith("/")) { - extServerPath += "/" - } - serverUri = "http://%s:%d/%s/spark-local-%s".format( - Utils.localIpAddress, extServerPort, extServerPath, localDirUuid) - } else { - // Create our own server - server = new HttpServer(localDir) - server.start() - serverUri = server.uri - } - logInfo("Local URI: " + serverUri) - } - - def stop() { - if (server != null) { - server.stop() - } - } - - def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = { - val dir = new File(shuffleDir, shuffleId + "/" + inputId) - dir.mkdirs() - val file = new File(dir, "" + outputId) - return file - } - - def getServerUri(): String = { - serverUri - } - - def newShuffleId(): Long = { - nextShuffleId.getAndIncrement() - } -} diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala deleted file mode 100644 index 594dbd235f..0000000000 --- a/core/src/main/scala/spark/ShuffledRDD.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark - -import java.util.{HashMap => JHashMap} - -class ShuffledRDDSplit(val idx: Int) extends Split { - override val index = idx - override def hashCode(): Int = idx -} - -class ShuffledRDD[K, V, C]( - @transient parent: RDD[(K, V)], - aggregator: Aggregator[K, V, C], - part : Partitioner) - extends RDD[(K, C)](parent.context) { - //override val partitioner = Some(part) - override val partitioner = Some(part) - - @transient - val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - - override def splits = splits_ - - override def preferredLocations(split: Split) = Nil - - val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part) - override val dependencies = List(dep) - - override def compute(split: Split): Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - def mergePair(k: K, c: C) { - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, aggregator.mergeCombiners(oldC, c)) - } - } - val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair) - return new Iterator[(K, C)] { - var iter = combiners.entrySet().iterator() - - def hasNext: Boolean = iter.hasNext() - - def next(): (K, C) = { - val entry = iter.next() - (entry.getKey, entry.getValue) - } - } - } -} diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala index e5ad8b52dc..7c3e8640e9 100644 --- a/core/src/main/scala/spark/SizeEstimator.scala +++ b/core/src/main/scala/spark/SizeEstimator.scala @@ -22,7 +22,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet * Based on the following JavaWorld article: * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html */ -object SizeEstimator extends Logging { +private[spark] object SizeEstimator extends Logging { // Sizes of primitive types private val BYTE_SIZE = 1 @@ -77,22 +77,18 @@ object SizeEstimator extends Logging { return System.getProperty("spark.test.useCompressedOops").toBoolean } try { - val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"; - val server = ManagementFactory.getPlatformMBeanServer(); + val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic" + val server = ManagementFactory.getPlatformMBeanServer() val bean = ManagementFactory.newPlatformMXBeanProxy(server, - hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]); + hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]) return bean.getVMOption("UseCompressedOops").getValue.toBoolean } catch { - case e: IllegalArgumentException => { - logWarning("Exception while trying to check if compressed oops is enabled", e) - // Fall back to checking if maxMemory < 32GB - return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) - } - - case e: SecurityException => { - logWarning("No permission to create MBeanServer", e) - // Fall back to checking if maxMemory < 32GB - return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) + case e: Exception => { + // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB + val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) + val guessInWords = if (guess) "yes" else "not" + logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) + return guess } } } @@ -146,6 +142,10 @@ object SizeEstimator extends Logging { val cls = obj.getClass if (cls.isArray) { visitArray(obj, cls, state) + } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) { + // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses + // the size estimator since it references the whole REPL. Do nothing in this case. In + // general all ClassLoaders and Classes will be shared between objects anyway. } else { val classInfo = getClassInfo(cls) state.size += classInfo.shellSize diff --git a/core/src/main/scala/spark/SoftReferenceCache.scala b/core/src/main/scala/spark/SoftReferenceCache.scala index ce9370c5d7..3dd0a4b1f9 100644 --- a/core/src/main/scala/spark/SoftReferenceCache.scala +++ b/core/src/main/scala/spark/SoftReferenceCache.scala @@ -5,7 +5,7 @@ import com.google.common.collect.MapMaker /** * An implementation of Cache that uses soft references. */ -class SoftReferenceCache extends Cache { +private[spark] class SoftReferenceCache extends Cache { val map = new MapMaker().softValues().makeMap[Any, Any]() override def get(datasetId: Any, partition: Int): Any = diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b0f5e12a76..becf737597 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -2,13 +2,15 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger +import java.net.{URI, URLClassLoader} + +import scala.collection.Map +import scala.collection.generic.Growable +import scala.collection.mutable.{ArrayBuffer, HashMap} import akka.actor.Actor import akka.actor.Actor._ - -import scala.collection.mutable.ArrayBuffer - -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.SequenceFileInputFormat @@ -25,18 +27,18 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.FileInputFormat import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.TextInputFormat - import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.{Job => NewHadoopJob} - import org.apache.mesos.{Scheduler, MesosNativeLibrary} import spark.broadcast._ - +import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult - +import spark.rdd.HadoopRDD +import spark.rdd.NewHadoopRDD +import spark.rdd.UnionRDD import spark.scheduler.ShuffleMapTask import spark.scheduler.DAGScheduler import spark.scheduler.TaskScheduler @@ -45,14 +47,40 @@ import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, C import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import spark.storage.BlockManagerMaster +/** + * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark + * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI. + * @param sparkHome Location where Spark is installed on cluster nodes. + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes. + */ class SparkContext( master: String, - frameworkName: String, + jobName: String, val sparkHome: String, - val jars: Seq[String]) + jars: Seq[String], + environment: Map[String, String]) extends Logging { - - def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + * @param sparkHome Location where Spark is installed on cluster nodes. + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) = + this(master, jobName, sparkHome, jars, Map()) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + */ + def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map()) // Ensure logging is initialized before we spawn any threads initLogging() @@ -68,37 +96,78 @@ class SparkContext( private val isLocal = (master == "local" || master.startsWith("local[")) // Create the Spark execution environment (cache, map output tracker, etc) - val env = SparkEnv.createFromSystemProperties( + private[spark] val env = SparkEnv.createFromSystemProperties( System.getProperty("spark.master.host"), System.getProperty("spark.master.port").toInt, true, isLocal) SparkEnv.set(env) - Broadcast.initialize(true) + + // Used to store a URL for each static file/jar together with the file's local timestamp + private[spark] val addedFiles = HashMap[String, Long]() + private[spark] val addedJars = HashMap[String, Long]() + + // Add each JAR given through the constructor + jars.foreach { addJar(_) } + + // Environment variables to pass to our executors + private[spark] val executorEnvs = HashMap[String, String]() + for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", + "SPARK_TESTING")) { + val value = System.getenv(key) + if (value != null) { + executorEnvs(key) = value + } + } + executorEnvs ++= environment // Create and start the scheduler private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """(spark://.*)""".r master match { - case "local" => - new LocalScheduler(1, 0) + case "local" => + new LocalScheduler(1, 0, this) - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0) + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt, 0, this) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt) + new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName) + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang. + val memoryPerSlaveInt = memoryPerSlave.toInt + val sparkMemEnv = System.getenv("SPARK_MEM") + val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512 + if (sparkMemEnvInt > memoryPerSlaveInt) { + throw new SparkException( + "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format( + memoryPerSlaveInt, sparkMemEnvInt)) + } + + val scheduler = new ClusterScheduler(this) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val sparkUrl = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName) scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } scheduler case _ => @@ -106,9 +175,9 @@ class SparkContext( val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, master, frameworkName) + new CoarseMesosSchedulerBackend(scheduler, this, master, jobName) } else { - new MesosSchedulerBackend(scheduler, this, master, frameworkName) + new MesosSchedulerBackend(scheduler, this, master, jobName) } scheduler.initialize(backend) scheduler @@ -120,14 +189,20 @@ class SparkContext( // Methods for creating RDDs - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = { + /** Distribute a local Scala collection to form an RDD. */ + def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollection[T](this, seq, numSlices) } - - def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = { + + /** Distribute a local Scala collection to form an RDD. */ + def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { parallelize(seq, numSlices) } + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = { hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits) .map(pair => pair._2.toString) @@ -164,19 +239,31 @@ class SparkContext( } /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, - * values and the InputFormat so that users don't need to pass them directly. + * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * values and the InputFormat so that users don't need to pass them directly. Instead, callers + * can just write, for example, + * {{{ + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits) + * }}} */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) : RDD[(K, V)] = { hadoopFile(path, - fm.erasure.asInstanceOf[Class[F]], + fm.erasure.asInstanceOf[Class[F]], km.erasure.asInstanceOf[Class[K]], vm.erasure.asInstanceOf[Class[V]], minSplits) } + /** + * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * values and the InputFormat so that users don't need to pass them directly. Instead, callers + * can just write, for example, + * {{{ + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) + * }}} + */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = hadoopFile[K, V, F](path, defaultMinSplits) @@ -192,7 +279,7 @@ class SparkContext( new Configuration) } - /** + /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ @@ -208,7 +295,7 @@ class SparkContext( new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) } - /** + /** * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. */ @@ -220,7 +307,7 @@ class SparkContext( new NewHadoopRDD(this, fClass, kClass, vClass, conf) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types */ + /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -230,18 +317,23 @@ class SparkContext( hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits) } + /** Get an RDD for a Hadoop SequenceFile with given key and value types. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = sequenceFile(path, keyClass, valueClass, defaultMinSplits) /** - * Version of sequenceFile() for types implicitly convertible to Writables through a - * WritableConverter. + * Version of sequenceFile() for types implicitly convertible to Writables through a + * WritableConverter. For example, to access a SequenceFile where the keys are Text and the + * values are IntWritable, you could simply write + * {{{ + * sparkContext.sequenceFile[String, Int](path, ...) + * }}} * * WritableConverters are provided in a somewhat strange way (by an implicit function) to support - * both subclasses of Writable and types for which we define a converter (e.g. Int to + * both subclasses of Writable and types for which we define a converter (e.g. Int to * IntWritable). The most natural thing would've been to have implicit objects for the * converters, but then we couldn't have an object for every subclass of Writable (you can't - * have a parameterized singleton object). We use functions instead to create a new converter + * have a parameterized singleton object). We use functions instead to create a new converter * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to * allow it to figure out the Writable class to use in the subclass case. */ @@ -266,7 +358,7 @@ class SparkContext( * that there's very little effort required to save arbitrary objects. */ def objectFile[T: ClassManifest]( - path: String, + path: String, minSplits: Int = defaultMinSplits ): RDD[T] = { sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits) @@ -276,43 +368,118 @@ class SparkContext( /** Build the union of a list of RDDs. */ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) - /** Build the union of a list of RDDs. */ + /** Build the union of a list of RDDs passed as variable-length arguments. */ def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] = new UnionRDD(this, Seq(first) ++ rest) // Methods for creating shared variables + /** + * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values + * to using the `+=` method. Only the master can access the accumulator's `value`. + */ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) = new Accumulator(initialValue, param) /** - * Create an accumulable shared variable, with a `+=` method + * Create an [[spark.Accumulable]] shared variable, with a `+=` method * @tparam T accumulator type * @tparam R type that can be added to the accumulator */ def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) = new Accumulable(initialValue, param) + /** + * Create an accumulator from a "mutable collection" type. + * + * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by + * standard mutable collections. So you can use this with mutable Map, Set, etc. + */ + def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = { + val param = new GrowableAccumulableParam[R,T] + new Accumulable(initialValue, param) + } + + /** + * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for + * reading it in distributed functions. The variable will be sent to each cluster only once. + */ + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal) + + /** + * Add a file to be downloaded into the working directory of this Spark job on every node. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addFile(path: String) { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) + case _ => path + } + addedFiles(key) = System.currentTimeMillis + + // Fetch the file locally in case the task is executed locally + val filename = new File(path.split("/").last) + Utils.fetchFile(path, new File(".")) + + logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) + } - // Keep around a weak hash map of values to Cached versions? - def broadcast[T](value: T) = Broadcast.getBroadcastFactory.newBroadcast[T] (value, isLocal) + /** + * Clear the job's list of files added by `addFile` so that they do not get donwloaded to + * any new nodes. + */ + def clearFiles() { + addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } + addedFiles.clear() + } - // Stop the SparkContext + /** + * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. + * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. + */ + def addJar(path: String) { + val uri = new URI(path) + val key = uri.getScheme match { + case null | "file" => env.httpFileServer.addJar(new File(uri.getPath)) + case _ => path + } + addedJars(key) = System.currentTimeMillis + logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key)) + } + + /** + * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to + * any new nodes. + */ + def clearJars() { + addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() } + addedJars.clear() + } + + /** Shut down the SparkContext. */ def stop() { dagScheduler.stop() dagScheduler = null taskScheduler = null - // TODO: Broadcast.stop(), Cache.stop()? + // TODO: Cache.stop()? env.stop() + // Clean up locally linked files + clearFiles() + clearJars() SparkEnv.set(null) ShuffleMapTask.clearCache() logInfo("Successfully stopped SparkContext") } - // Get Spark's home location from either a value set through the constructor, - // or the spark.home Java property, or the SPARK_HOME environment variable - // (in that order of preference). If neither of these is set, return None. - def getSparkHome(): Option[String] = { + /** + * Get Spark's home location from either a value set through the constructor, + * or the spark.home Java property, or the SPARK_HOME environment variable + * (in that order of preference). If neither of these is set, return None. + */ + private[spark] def getSparkHome(): Option[String] = { if (sparkHome != null) { Some(sparkHome) } else if (System.getProperty("spark.home") != null) { @@ -327,7 +494,7 @@ class SparkContext( /** * Run a function on a given set of partitions in an RDD and return the results. This is the main * entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies - * whether the scheduler can run the computation on the master rather than shipping it out to the + * whether the scheduler can run the computation on the master rather than shipping it out to the * cluster, for short actions like first(). */ def runJob[T, U: ClassManifest]( @@ -336,22 +503,27 @@ class SparkContext( partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - logInfo("Starting job...") + val callSite = Utils.getSparkCallSite + logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, allowLocal) - logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal) + logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") result } + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ def runJob[T, U: ClassManifest]( rdd: RDD[T], - func: Iterator[T] => U, + func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal) } - + /** * Run a job on all partitions in an RDD and return the results in an array. */ @@ -359,6 +531,9 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } + /** + * Run a job on all partitions in an RDD and return the results in an array. + */ def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { runJob(rdd, func, 0 until rdd.splits.size, false) } @@ -372,38 +547,37 @@ class SparkContext( evaluator: ApproximateEvaluator[U, R], timeout: Long ): PartialResult[R] = { - logInfo("Starting job...") + val callSite = Utils.getSparkCallSite + logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout) - logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout) + logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") result } - // Clean a closure to make it ready to serialized and send to tasks - // (removes unreferenced variables in $outer's, updates REPL variables) + /** + * Clean a closure to make it ready to serialized and send to tasks + * (removes unreferenced variables in $outer's, updates REPL variables) + */ private[spark] def clean[F <: AnyRef](f: F): F = { ClosureCleaner.clean(f) return f } - // Default level of parallelism to use when not given by user (e.g. for reduce tasks) + /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ def defaultParallelism: Int = taskScheduler.defaultParallelism - // Default min number of splits for Hadoop RDDs when not given by user + /** Default min number of splits for Hadoop RDDs when not given by user */ def defaultMinSplits: Int = math.min(defaultParallelism, 2) private var nextShuffleId = new AtomicInteger(0) - private[spark] def newShuffleId(): Int = { - nextShuffleId.getAndIncrement() - } - + private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement() + private var nextRddId = new AtomicInteger(0) - // Register a new RDD, returning its RDD ID - private[spark] def newRddId(): Int = { - nextRddId.getAndIncrement() - } + /** Register a new RDD, returning its RDD ID */ + private[spark] def newRddId(): Int = nextRddId.getAndIncrement() } /** @@ -425,7 +599,7 @@ object SparkContext { implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) - + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) @@ -446,7 +620,7 @@ object SparkContext { implicit def longToLongWritable(l: Long) = new LongWritable(l) implicit def floatToFloatWritable(f: Float) = new FloatWritable(f) - + implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d) implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b) @@ -457,7 +631,7 @@ object SparkContext { private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u - + new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]], arr.map(x => anyToWritable(x)).toArray) } @@ -485,8 +659,10 @@ object SparkContext { implicit def writableWritableConverter[T <: Writable]() = new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T]) - // Find the JAR from which a given class was loaded, to make it easy for users to pass - // their JARs to SparkContext + /** + * Find the JAR from which a given class was loaded, to make it easy for users to pass + * their JARs to SparkContext + */ def jarOfClass(cls: Class[_]): Seq[String] = { val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") if (uri != null) { @@ -501,8 +677,8 @@ object SparkContext { Nil } } - - // Find the JAR that contains the class of a particular object + + /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) } @@ -514,7 +690,7 @@ object SparkContext { * that doesn't know the type of T when it is created. This sounds strange but is necessary to * support converting subclasses of Writable to themselves (writableWritableConverter). */ -class WritableConverter[T]( +private[spark] class WritableConverter[T]( val writableClass: ClassManifest[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 694db6b2a3..4c6ec6cc6e 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,43 +1,57 @@ package spark import akka.actor.ActorSystem +import akka.actor.ActorSystemImpl +import akka.remote.RemoteActorRefProvider +import serializer.Serializer +import spark.broadcast.BroadcastManager import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager import spark.util.AkkaUtils +/** + * Holds all the runtime environment objects for a running Spark instance (either master or worker), + * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently + * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these + * objects needs to have the right SparkEnv set. You can get the current environment with + * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. + */ class SparkEnv ( val actorSystem: ActorSystem, - val cache: Cache, val serializer: Serializer, val closureSerializer: Serializer, val cacheTracker: CacheTracker, val mapOutputTracker: MapOutputTracker, val shuffleFetcher: ShuffleFetcher, - val shuffleManager: ShuffleManager, + val broadcastManager: BroadcastManager, val blockManager: BlockManager, - val connectionManager: ConnectionManager + val connectionManager: ConnectionManager, + val httpFileServer: HttpFileServer ) { /** No-parameter constructor for unit tests. */ def this() = { - this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null) + this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null) } def stop() { + httpFileServer.stop() mapOutputTracker.stop() cacheTracker.stop() shuffleFetcher.stop() - shuffleManager.stop() + broadcastManager.stop() blockManager.stop() blockManager.master.stop() actorSystem.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release actorSystem.awaitTermination() } } -object SparkEnv { +object SparkEnv extends Logging { private val env = new ThreadLocal[SparkEnv] def set(e: SparkEnv) { @@ -63,63 +77,55 @@ object SparkEnv { System.setProperty("spark.master.port", boundPort.toString) } - val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") - val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + val classLoader = Thread.currentThread.getContextClassLoader + + // Create an instance of the class named by the given Java system property, or by + // defaultClassName if the property is not set, and return it as a T + def instantiateClass[T](propertyName: String, defaultClassName: String): T = { + val name = System.getProperty(propertyName, defaultClassName) + Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] + } + + val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer") val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal) - val blockManager = new BlockManager(blockManagerMaster, serializer) - val connectionManager = blockManager.connectionManager - - val shuffleManager = new ShuffleManager() + val connectionManager = blockManager.connectionManager + + val broadcastManager = new BroadcastManager(isMaster) - val closureSerializerClass = - System.getProperty("spark.closure.serializer", "spark.JavaSerializer") - val closureSerializer = - Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] - val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") - val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] + val closureSerializer = instantiateClass[Serializer]( + "spark.closure.serializer", "spark.JavaSerializer") val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) blockManager.cacheTracker = cacheTracker val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) - val shuffleFetcherClass = - System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") - val shuffleFetcher = - Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher] - - /* - if (System.getProperty("spark.stream.distributed", "false") == "true") { - val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] - if (isLocal || !isMaster) { - (new Thread() { - override def run() { - println("Wait started") - Thread.sleep(60000) - println("Wait ended") - val receiverClass = Class.forName("spark.stream.TestStreamReceiver4") - val constructor = receiverClass.getConstructor(blockManagerClass) - val receiver = constructor.newInstance(blockManager) - receiver.asInstanceOf[Thread].start() - } - }).start() - } + val shuffleFetcher = instantiateClass[ShuffleFetcher]( + "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") + + val httpFileServer = new HttpFileServer() + httpFileServer.initialize() + System.setProperty("spark.fileserver.uri", httpFileServer.serverUri) + + // Warn about deprecated spark.cache.class property + if (System.getProperty("spark.cache.class") != null) { + logWarning("The spark.cache.class property is no longer being used! Specify storage " + + "levels using the RDD.persist() method instead.") } - */ new SparkEnv( actorSystem, - cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher, - shuffleManager, + broadcastManager, blockManager, - connectionManager) + connectionManager, + httpFileServer) } } diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala index 6e4eb25ed4..420c54bc9a 100644 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -7,10 +7,16 @@ import spark.storage.BlockManagerId * tasks several times for "ephemeral" failures, and only report back failures that require some * old stages to be resubmitted, such as shuffle map fetch failures. */ -sealed trait TaskEndReason +private[spark] sealed trait TaskEndReason -case object Success extends TaskEndReason +private[spark] case object Success extends TaskEndReason + +private[spark] case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it + +private[spark] case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason -case class ExceptionFailure(exception: Throwable) extends TaskEndReason -case class OtherFailure(message: String) extends TaskEndReason + +private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason + +private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/spark/TaskState.scala b/core/src/main/scala/spark/TaskState.scala index 9566b52432..78eb33a628 100644 --- a/core/src/main/scala/spark/TaskState.scala +++ b/core/src/main/scala/spark/TaskState.scala @@ -2,7 +2,7 @@ package spark import org.apache.mesos.Protos.{TaskState => MesosTaskState} -object TaskState +private[spark] object TaskState extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 5eda1011f9..567c4b1475 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -1,18 +1,18 @@ package spark import java.io._ -import java.net.InetAddress +import java.net.{InetAddress, URL, URI} +import java.util.{Locale, Random, UUID} import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} - +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import scala.collection.mutable.ArrayBuffer -import scala.util.Random -import java.util.{Locale, UUID} import scala.io.Source /** * Various utility methods used by Spark. */ -object Utils { +private object Utils extends Logging { /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -71,7 +71,7 @@ object Utils { while (dir == null) { attempts += 1 if (attempts > maxAttempts) { - throw new IOException("Failed to create a temp directory after " + maxAttempts + + throw new IOException("Failed to create a temp directory after " + maxAttempts + " attempts!") } try { @@ -116,22 +116,84 @@ object Utils { copyStream(in, out, true) } + /** Download a file from a given URL to the local filesystem */ + def downloadFile(url: URL, localPath: String) { + val in = url.openStream() + val out = new FileOutputStream(localPath) + Utils.copyStream(in, out, true) + } + + /** + * Download a file requested by the executor. Supports fetching the file in a variety of ways, + * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. + */ + def fetchFile(url: String, targetDir: File) { + val filename = url.split("/").last + val targetFile = new File(targetDir, filename) + val uri = new URI(url) + uri.getScheme match { + case "http" | "https" | "ftp" => + logInfo("Fetching " + url + " to " + targetFile) + val in = new URL(url).openStream() + val out = new FileOutputStream(targetFile) + Utils.copyStream(in, out, true) + case "file" | null => + // Remove the file if it already exists + targetFile.delete() + // Symlink the file locally. + if (uri.isAbsolute) { + // url is absolute, i.e. it starts with "file:///". Extract the source + // file's absolute path from the url. + val sourceFile = new File(uri) + logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath) + } else { + // url is not absolute, i.e. itself is the path to the source file. + logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath) + FileUtil.symLink(url, targetFile.getAbsolutePath) + } + case _ => + // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others + val uri = new URI(url) + val conf = new Configuration() + val fs = FileSystem.get(uri, conf) + val in = fs.open(new Path(uri)) + val out = new FileOutputStream(targetFile) + Utils.copyStream(in, out, true) + } + // Decompress the file if it's a .tar or .tar.gz + if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xzf", filename), targetDir) + } else if (filename.endsWith(".tar")) { + logInfo("Untarring " + filename) + Utils.execute(Seq("tar", "-xf", filename), targetDir) + } + // Make the file executable - That's necessary for scripts + FileUtil.chmod(filename, "a+x") + } + /** * Shuffle the elements of a collection into a random order, returning the * result in a new collection. Unlike scala.util.Random.shuffle, this method * uses a local random number generator, avoiding inter-thread contention. */ - def randomize[T](seq: TraversableOnce[T]): Seq[T] = { - val buf = new ArrayBuffer[T]() - buf ++= seq - val rand = new Random() - for (i <- (buf.size - 1) to 1 by -1) { + def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { + randomizeInPlace(seq.toArray) + } + + /** + * Shuffle the elements of an array into a random order, modifying the + * original array. Returns the original array. + */ + def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = { + for (i <- (arr.length - 1) to 1 by -1) { val j = rand.nextInt(i) - val tmp = buf(j) - buf(j) = buf(i) - buf(i) = tmp + val tmp = arr(j) + arr(j) = arr(i) + arr(i) = tmp } - buf + arr } /** @@ -155,7 +217,7 @@ object Utils { def localHostName(): String = { customHostname.getOrElse(InetAddress.getLocalHost.getHostName) } - + /** * Returns a standard ThreadFactory except all threads are daemons. */ @@ -179,10 +241,10 @@ object Utils { return threadPool } - + /** - * Return the string to tell how long has passed in seconds. The passing parameter should be in - * millisecond. + * Return the string to tell how long has passed in seconds. The passing parameter should be in + * millisecond. */ def getUsedTimeMs(startTimeMs: Long): String = { return " " + (System.currentTimeMillis - startTimeMs) + " ms " @@ -294,4 +356,43 @@ object Utils { def execute(command: Seq[String]) { execute(command, new File(".")) } + + + /** + * When called inside a class in the spark package, returns the name of the user code class + * (outside the spark package) that called into Spark, as well as which Spark method they called. + * This is used, for example, to tell users where in their code each RDD got created. + */ + def getSparkCallSite: String = { + val trace = Thread.currentThread.getStackTrace().filter( el => + (!el.getMethodName.contains("getStackTrace"))) + + // Keep crawling up the stack trace until we find the first function not inside of the spark + // package. We track the last (shallowest) contiguous Spark method. This might be an RDD + // transformation, a SparkContext function (such as parallelize), or anything else that leads + // to instantiation of an RDD. We also track the first (deepest) user method, file, and line. + var lastSparkMethod = "<unknown>" + var firstUserFile = "<unknown>" + var firstUserLine = 0 + var finished = false + + for (el <- trace) { + if (!finished) { + if (el.getClassName.startsWith("spark.") && !el.getClassName.startsWith("spark.examples.")) { + lastSparkMethod = if (el.getMethodName == "<init>") { + // Spark method is a constructor; get its class name + el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1) + } else { + el.getMethodName + } + } + else { + firstUserLine = el.getLineNumber + firstUserFile = el.getFileName + finished = true + } + } + } + "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine) + } } diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala index 7c0b17c45e..843e1bd18b 100644 --- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala @@ -22,8 +22,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav import JavaDoubleRDD.fromRDD + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) // first() has to be overriden here in order for its return type to be Double instead of Object. @@ -31,36 +36,63 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav // Transformations (return a new RDD) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct()) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = fromRDD(srdd.filter(x => f(x).booleanValue())) + /** + * Return a sampled subset of this RDD. + */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD = fromRDD(srdd.sample(withReplacement, fraction, seed)) + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) // Double RDD functions + /** Return the sum of the elements in this RDD. */ def sum(): Double = srdd.sum() + /** Return a [[spark.StatCounter]] describing the elements in this RDD. */ def stats(): StatCounter = srdd.stats() + /** Return the mean of the elements in this RDD. */ def mean(): Double = srdd.mean() + /** Return the variance of the elements in this RDD. */ def variance(): Double = srdd.variance() + /** Return the standard deviation of the elements in this RDD. */ def stdev(): Double = srdd.stdev() + /** Return the approximate mean of the elements in this RDD. */ def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = srdd.meanApprox(timeout, confidence) + /** Return the approximate mean of the elements in this RDD. */ def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) + /** Return the approximate sum of the elements in this RDD. */ def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = srdd.sumApprox(timeout, confidence) - + + /** Return the approximate sum of the elements in this RDD. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) } diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index c28a13b061..5c2be534ff 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -1,13 +1,5 @@ package spark.api.java -import spark.SparkContext.rddToPairRDDFunctions -import spark.api.java.function.{Function2 => JFunction2} -import spark.api.java.function.{Function => JFunction} -import spark.partial.BoundedDouble -import spark.partial.PartialResult -import spark.storage.StorageLevel -import spark._ - import java.util.{List => JList} import java.util.Comparator @@ -19,6 +11,17 @@ import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration +import spark.api.java.function.{Function2 => JFunction2} +import spark.api.java.function.{Function => JFunction} +import spark.partial.BoundedDouble +import spark.partial.PartialResult +import spark.OrderedRDDFunctions +import spark.storage.StorageLevel +import spark.HashPartitioner +import spark.Partitioner +import spark.RDD +import spark.SparkContext.rddToPairRDDFunctions + class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K], implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { @@ -31,21 +34,44 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif // Common RDD functions + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.persist(newLevel)) // Transformations (return a new RDD) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct()) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) + /** + * Return a sampled subset of this RDD. + */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed)) + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.union(other.rdd)) @@ -56,7 +82,21 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif override def first(): (K, V) = rdd.first() // Pair RDD functions - + + /** + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * functions: + * + * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) + * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) + * - `mergeCombiners`, to combine two C's into a single one. + * + * In addition, users can control the partitioning of the output RDD, and whether to perform + * map-side aggregation (if a mapper can produce multiple items with the same key). + */ def combineByKey[C](createCombiner: Function[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], @@ -71,50 +111,113 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif )) } + /** + * Simplified version of combineByKey that hash-partitions the output RDD. + */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], numSplits: Int): JavaPairRDD[K, C] = combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits)) + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] = fromRDD(rdd.reduceByKey(partitioner, func)) + /** + * Merge the values for each key using an associative reduce function, but return the results + * immediately to the master as a Map. This will also perform the merging locally on each mapper + * before sending results to a reducer, similarly to a "combiner" in MapReduce. + */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = mapAsJavaMap(rdd.reduceByKeyLocally(func)) + /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + /** + * (Experimental) Approximate version of countByKey that can return a partial result if it does + * not finish within a timeout. + */ def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits. + */ def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] = fromRDD(rdd.reduceByKey(func, numSplits)) + /** + * Group the values for each key in the RDD into a single sequence. Allows controlling the + * partitioning of the resulting key-value pair RDD by passing a Partitioner. + */ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] = fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with into `numSplits` partitions. + */ def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] = fromRDD(groupByResultToJava(rdd.groupByKey(numSplits))) + /** + * Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine` + * is true, Spark will group values of the same key together on the map side before the + * repartitioning, to only send each key over the network once. If a large number of + * duplicated keys are expected, and the size of the keys are large, `mapSideCombine` should + * be set to true. + */ def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] = fromRDD(rdd.partitionBy(partitioner)) + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. + */ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other, partitioner)) + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to + * partition the output RDD. + */ def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) : JavaPairRDD[K, (V, Option[W])] = fromRDD(rdd.leftOuterJoin(other, partitioner)) + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to + * partition the output RDD. + */ def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner) : JavaPairRDD[K, (Option[V], W)] = fromRDD(rdd.rightOuterJoin(other, partitioner)) + /** + * Simplified version of combineByKey that hash-partitions the resulting RDD using the default + * parallelism level. + */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { @@ -123,40 +226,94 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners)) } + /** + * Merge the values for each key using an associative reduce function. This will also perform + * the merging locally on each mapper before sending results to a reducer, similarly to a + * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level. + */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { val partitioner = rdd.defaultPartitioner(rdd) fromRDD(reduceByKey(partitioner, func)) } + /** + * Group the values for each key in the RDD into a single sequence. Hash-partitions the + * resulting RDD with the default parallelism level. + */ def groupByKey(): JavaPairRDD[K, JList[V]] = fromRDD(groupByResultToJava(rdd.groupByKey())) + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other)) + /** + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Performs a hash join across the cluster. + */ def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other, numSplits)) + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * using the default level of parallelism. + */ def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] = fromRDD(rdd.leftOuterJoin(other)) + /** + * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the + * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the + * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output + * into `numSplits` partitions. + */ def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] = fromRDD(rdd.leftOuterJoin(other, numSplits)) + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD using the default parallelism level. + */ def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] = fromRDD(rdd.rightOuterJoin(other)) + /** + * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the + * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the + * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting + * RDD into the given number of partitions. + */ def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] = fromRDD(rdd.rightOuterJoin(other, numSplits)) + /** + * Return the key-value pairs in this RDD to the master as a Map. + */ def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + /** + * Pass each value in the key-value pair RDD through a map function without changing the keys; + * this also retains the original RDD's partitioning. + */ def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] fromRDD(rdd.mapValues(f)) } + /** + * Pass each value in the key-value pair RDD through a flatMap function without changing the + * keys; this also retains the original RDD's partitioning. + */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ def fn = (x: V) => f.apply(x).asScala @@ -165,37 +322,68 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(rdd.flatMapValues(fn)) } + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner) : JavaPairRDD[K, (JList[V], JList[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner))) + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner) : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other))) + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) + /** + * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the + * list of values for that key in `this` as well as `other`. + */ def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits))) + /** + * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a + * tuple with the list of values for that key in `this`, `other1` and `other2`. + */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int) : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits))) + /** Alias for cogroup. */ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) + /** Alias for cogroup. */ def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) + /** + * Return the list of values in the RDD for key `key`. This operation is done efficiently if the + * RDD has a known partitioner by only searching the partition that the key maps to. + */ def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key)) + /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( path: String, keyClass: Class[_], @@ -205,6 +393,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsHadoopFile[F <: OutputFormat[_, _]]( path: String, keyClass: Class[_], @@ -213,6 +402,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass) } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, keyClass: Class[_], @@ -222,6 +412,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf) } + /** Output the RDD to any Hadoop-supported file system. */ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( path: String, keyClass: Class[_], @@ -230,21 +421,49 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass) } + /** + * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for + * that storage system. The JobConf should set an OutputFormat and any output paths required + * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop + * MapReduce job. + */ def saveAsHadoopDataset(conf: JobConf) { rdd.saveAsHadoopDataset(conf) } - - // Ordered RDD Functions + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements in + * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an + * ordered list of records (in the `save` case, they will be written to multiple `part-X` files + * in the filesystem, in order of the keys). + */ def sortByKey(): JavaPairRDD[K, V] = sortByKey(true) + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = { val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] sortByKey(comp, true) } + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true) + /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = { class KeyOrdering(val a: K) extends Ordered[K] { override def compare(b: K) = comp.compare(a, b) @@ -274,4 +493,4 @@ object JavaPairRDD { new JavaPairRDD[K, V](rdd) implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd -}
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index 541aa1e60b..ac31350ec3 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -11,20 +11,43 @@ JavaRDDLike[T, JavaRDD[T]] { // Common RDD functions + /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) + /** + * Set this RDD's storage level to persist its values across operations after the first time + * it is computed. Can only be called once on each RDD. + */ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel)) // Transformations (return a new RDD) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct()) + /** + * Return a new RDD containing the distinct elements in this RDD. + */ + def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits)) + + /** + * Return a new RDD containing only the elements that satisfy a predicate. + */ def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = wrapRDD(rdd.filter((x => f(x).booleanValue()))) + /** + * Return a sampled subset of this RDD. + */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) - + + /** + * Return the union of this RDD and another one. Any identical elements will appear multiple + * times (use `.distinct()` to eliminate them). + */ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) } diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 785dd96394..13fcee1004 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -19,41 +19,71 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] + /** Set of partitions in this RDD. */ def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq) + /** The [[spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context - + + /** A unique ID for this RDD (within its SparkContext). */ def id: Int = rdd.id + /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel: StorageLevel = rdd.getStorageLevel + /** + * Internal method to this RDD; will read from cache if applicable, or otherwise compute it. + * This should ''not'' be called by users directly, but is available for implementors of custom + * subclasses of RDD. + */ def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split)) // Transformations (return a new RDD) + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ def map[R](f: JFunction[T, R]): JavaRDD[R] = new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType()) + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ def map[R](f: DoubleFunction[T]): JavaDoubleRDD = new JavaDoubleRDD(rdd.map(x => f(x).doubleValue())) + /** + * Return a new RDD by applying a function to all elements of this RDD. + */ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) } + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType()) } + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue())) } + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala @@ -61,29 +91,50 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) } + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType()) } + + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue())) } + /** + * Return a new RDD by applying a function to each partition of this RDD. + */ def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]): JavaPairRDD[K, V] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) } + /** + * Return an RDD created by coalescing all elements within each partition into an array. + */ def glom(): JavaRDD[JList[T]] = new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq))) + /** + * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of + * elements (a, b) where a is in `this` and b is in `other`. + */ def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest, other.classManifest) + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] @@ -92,6 +143,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) } + /** + * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements + * mapping to that key. + */ def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] @@ -100,56 +155,114 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm) } + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: String): JavaRDD[String] = rdd.pipe(command) + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: JList[String]): JavaRDD[String] = rdd.pipe(asScalaBuffer(command)) + /** + * Return an RDD created by piping elements to a forked external process. + */ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] = rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env)) // Actions (launch a job to return a value to the user program) - + + /** + * Applies a function f to all elements of this RDD. + */ def foreach(f: VoidFunction[T]) { val cleanF = rdd.context.clean(f) rdd.foreach(cleanF) } + /** + * Return an array that contains all of the elements in this RDD. + */ def collect(): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.collect().toSeq new java.util.ArrayList(arr) } - + + /** + * Reduces the elements of this RDD using the specified associative binary operator. + */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using a + * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to + * modify t1 and return it as its result value to avoid object allocation; however, it should not + * modify t2. + */ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T = rdd.fold(zeroValue)(f) + /** + * Aggregate the elements of each partition, and then the results for all the partitions, using + * given combine functions and a neutral "zero value". This function can return a different result + * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U + * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are + * allowed to modify and return their first argument instead of creating a new U to avoid memory + * allocation. + */ def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U], combOp: JFunction2[U, U, U]): U = rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType) + /** + * Return the number of elements in the RDD. + */ def count(): Long = rdd.count() + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. + */ def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = rdd.countApprox(timeout, confidence) + /** + * (Experimental) Approximate version of count() that returns a potentially incomplete result + * within a timeout, even if not all tasks have finished. + */ def countApprox(timeout: Long): PartialResult[BoundedDouble] = rdd.countApprox(timeout) + /** + * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final + * combine step happens locally on the master, equivalent to running a single reduce task. + */ def countByValue(): java.util.Map[T, java.lang.Long] = mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2))))) + /** + * (Experimental) Approximate version of countByValue(). + */ def countByValueApprox( timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + /** + * (Experimental) Approximate version of countByValue(). + */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = rdd.countByValueApprox(timeout).map(mapAsJavaMap) + /** + * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so + * it will be slow if a lot of partitions are required. In that case, use collect() to get the + * whole RDD instead. + */ def take(num: Int): JList[T] = { import scala.collection.JavaConversions._ val arr: java.util.Collection[T] = rdd.take(num).toSeq @@ -162,9 +275,18 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } + /** + * Return the first element in this RDD. + */ def first(): T = rdd.first() + /** + * Save this RDD as a text file, using string representations of elements. + */ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path) + /** + * Save this RDD as a SequenceFile of serialized objects. + */ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path) } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 08c92b145e..edbb187b1b 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -1,37 +1,78 @@ package spark.api.java -import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} -import spark.SparkContext.IntAccumulatorParam -import spark.SparkContext.DoubleAccumulatorParam -import spark.broadcast.Broadcast +import java.util.{Map => JMap} +import scala.collection.JavaConversions import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.JobConf - import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import spark.{Accumulator, AccumulatorParam, RDD, SparkContext} +import spark.SparkContext.IntAccumulatorParam +import spark.SparkContext.DoubleAccumulatorParam +import spark.broadcast.Broadcast -import scala.collection.JavaConversions - +/** + * A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and + * works with Java collections instead of Scala ones. + */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { - def this(master: String, frameworkName: String) = this(new SparkContext(master, frameworkName)) + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + */ + def this(master: String, jobName: String) = this(new SparkContext(master, jobName)) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + def this(master: String, jobName: String, sparkHome: String, jarFile: String) = + this(new SparkContext(master, jobName, sparkHome, Seq(jarFile))) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) = + this(new SparkContext(master, jobName, sparkHome, jars.toSeq)) + + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param jobName A name for your job, to display on the cluster web UI + * @param sparkHome The SPARK_HOME directory on the slave nodes + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + * @param environment Environment variables to set on worker nodes + */ + def this(master: String, jobName: String, sparkHome: String, jars: Array[String], + environment: JMap[String, String]) = + this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment)) - val env = sc.env + private[spark] val env = sc.env + /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) } + /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T]): JavaRDD[T] = parallelize(list, sc.defaultParallelism) - + /** Distribute a local Scala collection to form an RDD. */ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int) : JavaPairRDD[K, V] = { implicit val kcm: ClassManifest[K] = @@ -41,21 +82,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) } + /** Distribute a local Scala collection to form an RDD. */ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] = parallelizePairs(list, sc.defaultParallelism) + /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD = JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()), numSlices)) + /** Distribute a local Scala collection to form an RDD. */ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD = parallelizeDoubles(list, sc.defaultParallelism) + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ def textFile(path: String): JavaRDD[String] = sc.textFile(path) + /** + * Read a text file from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI, and return it as an RDD of Strings. + */ def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits) - /**Get an RDD for a Hadoop SequenceFile with given key and value types */ + /**Get an RDD for a Hadoop SequenceFile with given key and value types. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -66,6 +118,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits)) } + /**Get an RDD for a Hadoop SequenceFile. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): JavaPairRDD[K, V] = { implicit val kcm = ClassManifest.fromClass(keyClass) @@ -86,6 +139,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork sc.objectFile(path, minSplits)(cm) } + /** + * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and + * BytesWritable values that contain a serialized partition. This is still an experimental storage + * format and may not be supported exactly as is in future Spark releases. It will also be pretty + * slow if you use the default serializer (Java serialization), though the nice thing about it is + * that there's very little effort required to save arbitrary objects. + */ def objectFile[T](path: String): JavaRDD[T] = { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] @@ -109,6 +169,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits)) } + /** + * Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any + * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, + * etc). + */ def hadoopRDD[K, V, F <: InputFormat[K, V]]( conf: JobConf, inputFormatClass: Class[F], @@ -120,7 +185,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) } - /**Get an RDD for a Hadoop file with an arbitrary InputFormat */ + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], @@ -133,6 +198,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)) } + /** Get an RDD for a Hadoop file with an arbitrary InputFormat */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], @@ -174,12 +240,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) } + /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) implicit val cm: ClassManifest[T] = first.classManifest sc.union(rdds)(cm) } + /** Build the union of two or more RDDs. */ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) @@ -189,26 +257,49 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm) } + /** Build the union of two or more RDDs. */ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = { val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd) new JavaDoubleRDD(sc.union(rdds)) } + /** + * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values + * to using the `+=` method. Only the master can access the accumulator's `value`. + */ def intAccumulator(initialValue: Int): Accumulator[Int] = sc.accumulator(initialValue)(IntAccumulatorParam) + /** + * Create an [[spark.Accumulator]] double variable, which tasks can "add" values + * to using the `+=` method. Only the master can access the accumulator's `value`. + */ def doubleAccumulator(initialValue: Double): Accumulator[Double] = sc.accumulator(initialValue)(DoubleAccumulatorParam) + /** + * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values + * to using the `+=` method. Only the master can access the accumulator's `value`. + */ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] = sc.accumulator(initialValue)(accumulatorParam) + /** + * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for + * reading it in distributed functions. The variable will be sent to each cluster only once. + */ def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value) + /** Shut down the SparkContext. */ def stop() { sc.stop() } + /** + * Get Spark's home location from either a value set through the constructor, + * or the spark.home Java property, or the SPARK_HOME environment variable + * (in that order of preference). If neither of these is set, return None. + */ def getSparkHome(): Option[String] = sc.getSparkHome() } diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java new file mode 100644 index 0000000000..722af3c06c --- /dev/null +++ b/core/src/main/scala/spark/api/java/StorageLevels.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.storage.StorageLevel; + +/** + * Expose some commonly useful storage level constants. + */ +public class StorageLevels { + public static final StorageLevel NONE = new StorageLevel(false, false, false, 1); + public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1); + public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2); + public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1); + public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2); + public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1); + public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2); + public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1); + public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2); + public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1); + public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2); +} diff --git a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java index 7b6478c2cd..3a8192be3a 100644 --- a/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/scala/spark/api/java/function/DoubleFlatMapFunction.java @@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1; import java.io.Serializable; +/** + * A function that returns zero or more records of type Double from each input record. + */ // DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is // overloaded for both FlatMapFunction and DoubleFlatMapFunction. public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>> diff --git a/core/src/main/scala/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/spark/api/java/function/DoubleFunction.java index a03a72c835..c6ef76d088 100644 --- a/core/src/main/scala/spark/api/java/function/DoubleFunction.java +++ b/core/src/main/scala/spark/api/java/function/DoubleFunction.java @@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1; import java.io.Serializable; +/** + * A function that returns Doubles, and can be used to construct DoubleRDDs. + */ // DoubleFunction does not extend Function because some UDF functions, like map, // are overloaded for both Function and DoubleFunction. public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double> diff --git a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala index bcba38c569..e027cdacd3 100644 --- a/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/scala/spark/api/java/function/FlatMapFunction.scala @@ -1,5 +1,8 @@ package spark.api.java.function +/** + * A function that returns zero or more output records from each input record. + */ abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { @throws(classOf[Exception]) def call(x: T) : java.lang.Iterable[R] diff --git a/core/src/main/scala/spark/api/java/function/Function.java b/core/src/main/scala/spark/api/java/function/Function.java index f6f2e5fd76..dae8295f21 100644 --- a/core/src/main/scala/spark/api/java/function/Function.java +++ b/core/src/main/scala/spark/api/java/function/Function.java @@ -8,8 +8,9 @@ import java.io.Serializable; /** - * Base class for functions whose return types do not have special RDDs; DoubleFunction is - * handled separately, to allow DoubleRDDs to be constructed when mapping RDDs to doubles. + * Base class for functions whose return types do not create special RDDs. PairFunction and + * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed + * when mapping RDDs of other types. */ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable { public abstract R call(T t) throws Exception; diff --git a/core/src/main/scala/spark/api/java/function/Function2.java b/core/src/main/scala/spark/api/java/function/Function2.java index be48b173b8..69bf12c8c9 100644 --- a/core/src/main/scala/spark/api/java/function/Function2.java +++ b/core/src/main/scala/spark/api/java/function/Function2.java @@ -6,6 +6,9 @@ import scala.runtime.AbstractFunction2; import java.io.Serializable; +/** + * A two-argument function that takes arguments of type T1 and T2 and returns an R. + */ public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R> implements Serializable { diff --git a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java index c074b9c717..b3cc4df6aa 100644 --- a/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/scala/spark/api/java/function/PairFlatMapFunction.java @@ -7,6 +7,10 @@ import scala.runtime.AbstractFunction1; import java.io.Serializable; +/** + * A function that returns zero or more key-value pair records from each input record. The + * key-value pairs are represented as scala.Tuple2 objects. + */ // PairFlatMapFunction does not extend FlatMapFunction because flatMap is // overloaded for both FlatMapFunction and PairFlatMapFunction. public abstract class PairFlatMapFunction<T, K, V> diff --git a/core/src/main/scala/spark/api/java/function/PairFunction.java b/core/src/main/scala/spark/api/java/function/PairFunction.java index 7f5bb7de13..9fc6df4b88 100644 --- a/core/src/main/scala/spark/api/java/function/PairFunction.java +++ b/core/src/main/scala/spark/api/java/function/PairFunction.java @@ -7,6 +7,9 @@ import scala.runtime.AbstractFunction1; import java.io.Serializable; +/** + * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs. + */ // PairFunction does not extend Function because some UDF functions, like map, // are overloaded for both Function and PairFunction. public abstract class PairFunction<T, K, V> diff --git a/core/src/main/scala/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/spark/api/java/function/VoidFunction.scala index 0eefe337e8..b0096cf2bf 100644 --- a/core/src/main/scala/spark/api/java/function/VoidFunction.scala +++ b/core/src/main/scala/spark/api/java/function/VoidFunction.scala @@ -1,5 +1,8 @@ package spark.api.java.function +/** + * A function with no return value. + */ // This allows Java users to write void methods without having to return Unit. abstract class VoidFunction[T] extends Serializable { @throws(classOf[Exception]) diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala index d08e1e9fbf..923f5cdf4f 100644 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala +++ b/core/src/main/scala/spark/api/java/function/WrappedFunction1.scala @@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction1 * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply * isn't marked to allow that). */ -abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { +private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] { @throws(classOf[Exception]) def call(t: T): R diff --git a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala index c9d67d9771..2c6e9b1571 100644 --- a/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala +++ b/core/src/main/scala/spark/api/java/function/WrappedFunction2.scala @@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction2 * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply * isn't marked to allow that). */ -abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { +private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] { @throws(classOf[Exception]) def call(t1: T1, t2: T2): R diff --git a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala index e009d4e7db..ef27bbb502 100644 --- a/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/BitTorrentBroadcast.scala @@ -2,21 +2,26 @@ package spark.broadcast import java.io._ import java.net._ -import java.util.{BitSet, Comparator, Random, Timer, TimerTask, UUID} +import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ListBuffer, Map, Set} import scala.math import spark._ +import spark.storage.StorageLevel -class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { +private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) + extends Broadcast[T](id) + with Logging + with Serializable { def value = value_ - BitTorrentBroadcast.synchronized { - BitTorrentBroadcast.values.put(uuid, 0, value_) + def blockId: String = "broadcast_" + id + + MultiTracker.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @@ -25,8 +30,6 @@ extends Broadcast[T] with Logging with Serializable { @transient var totalBytes = -1 @transient var totalBlocks = -1 @transient var hasBlocks = new AtomicInteger(0) - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize // Used ONLY by Master to track how many unique blocks have been sent out @transient var sentBlocks = new AtomicInteger(0) @@ -45,37 +48,24 @@ extends Broadcast[T] with Logging with Serializable { // Used only in Workers @transient var ttGuide: TalkToGuide = null - @transient var rxSpeeds = new SpeedTracker - @transient var txSpeeds = new SpeedTracker - - @transient var hostAddress = Utils.localIpAddress + @transient var hostAddress = Utils.localIpAddress() @transient var listenPort = -1 @transient var guidePort = -1 - @transient var hasCopyInHDFS = false @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized if (!isLocal) { - sendBroadcast + sendBroadcast() } def sendBroadcast() { logInfo("Local host address: " + hostAddress) - // Store a persistent copy in HDFS - // TODO: Turned OFF for now. Related to persistence - // val out = new ObjectOutputStream(BroadcastCH.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // FIXME: Fix this at some point - hasCopyInHDFS = true - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) + var variableInfo = MultiTracker.blockifyObject(value_) // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here arrayOfBlocks = variableInfo.arrayOfBlocks totalBytes = variableInfo.totalBytes totalBlocks = variableInfo.totalBlocks @@ -95,9 +85,7 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER guideMR is created while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } + guidePortLock.synchronized { guidePortLock.wait() } } serveMR = new ServeMultipleRequests @@ -107,14 +95,12 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER serveMR is created while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Must always come AFTER listenPort is created val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) hasBlocksBitVector.synchronized { masterSource.hasBlocksBitVector = hasBlocksBitVector } @@ -123,46 +109,44 @@ extends Broadcast[T] with Logging with Serializable { listOfSources += masterSource // Register with the Tracker - registerBroadcast(uuid, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes, blockSize)) + MultiTracker.registerBroadcast(id, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - BitTorrentBroadcast.synchronized { - val cachedVal = BitTorrentBroadcast.values.get(uuid, 0) - - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Only the first worker in a node can ever be inside this 'else' - initializeWorkerVariables - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy - if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - BitTorrentBroadcast.values.put(uuid, 0, value_) - } else { - // TODO: This part won't work, cause HDFS writing is turned OFF - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - BitTorrentBroadcast.values.put(uuid, 0, value_) - fileIn.close() - } + MultiTracker.synchronized { + SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + logInfo("Started reading broadcast variable " + id) + // Initializing everything because Master will only send null/0 values + // Only the 1st worker in a node can be here. Others will get from cache + initializeWorkerVariables() + + logInfo("Local host address: " + hostAddress) + + // Start local ServeMultipleRequests thread first + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(id) + if (receptionSucceeded) { + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } else { + logError("Reading broadcast variable " + id + " failed") + } - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") } } } @@ -175,7 +159,6 @@ extends Broadcast[T] with Logging with Serializable { totalBytes = -1 totalBlocks = -1 hasBlocks = new AtomicInteger(0) - blockSize = -1 listenPortLock = new Object totalBlocksLock = new Object @@ -183,9 +166,6 @@ extends Broadcast[T] with Logging with Serializable { serveMR = null ttGuide = null - rxSpeeds = new SpeedTracker - txSpeeds = new SpeedTracker - hostAddress = Utils.localIpAddress listenPort = -1 @@ -194,75 +174,19 @@ extends Broadcast[T] with Logging with Serializable { stopBroadcast = false } - private def registerBroadcast(uuid: UUID, gInfo: SourceInfo) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - private def unregisterBroadcast(uuid: UUID) { - val socket = new Socket(Broadcast.MasterHostAddress, - Broadcast.MasterTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(Broadcast.UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send UUID of this broadcast - oosST.writeObject(uuid) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - private def getLocalSourceInfo: SourceInfo = { // Wait till hostName and listenPort are OK while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Wait till totalBlocks and totalBytes are OK while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } + totalBlocksLock.synchronized { totalBlocksLock.wait() } } var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + hostAddress, listenPort, totalBlocks, totalBytes) localSourceInfo.hasBlocks = hasBlocks.get @@ -274,7 +198,7 @@ extends Broadcast[T] with Logging with Serializable { } // Add new SourceInfo to the listOfSources. Update if it exists already. - // TODO: Optimizing just by OR-ing the BitVectors was BAD for performance + // Optimizing just by OR-ing the BitVectors was BAD for performance private def addToListOfSources(newSourceInfo: SourceInfo) { listOfSources.synchronized { if (listOfSources.contains(newSourceInfo)) { @@ -297,9 +221,9 @@ extends Broadcast[T] with Logging with Serializable { // Keep exchaning information until all blocks have been received while (hasBlocks.get < totalBlocks) { talkOnce - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) } // Talk one more time to let the Guide know of reception completion @@ -324,7 +248,7 @@ extends Broadcast[T] with Logging with Serializable { // Receive source information from Guide var suitableSources = oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logInfo("Received suitableSources from Master " + suitableSources) + logDebug("Received suitableSources from Master " + suitableSources) addToListOfSources(suitableSources) @@ -334,76 +258,17 @@ extends Broadcast[T] with Logging with Serializable { } } - def getGuideInfo(variableUUID: UUID): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToHDFS) + def receiveBroadcast(variableID: Long): Boolean = { + val gInfo = MultiTracker.getGuideInfo(variableID) - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(Broadcast.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send UUID and receive GuideInfo - oosTracker.writeObject(uuid) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => { - logInfo("getGuideInfo had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(BitTorrentBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def receiveBroadcast(variableUUID: UUID): Boolean = { - val gInfo = getGuideInfo(variableUUID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToHDFS || - gInfo.listenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { return false } // Wait until hostAddress and listenPort are created by the // ServeMultipleRequests thread while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Setup initial states of variables @@ -411,11 +276,8 @@ extends Broadcast[T] with Logging with Serializable { arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) hasBlocksBitVector = new BitSet(totalBlocks) numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = gInfo.totalBytes - blockSize = gInfo.blockSize // Start ttGuide to periodically talk to the Guide var ttGuide = new TalkToGuide(gInfo) @@ -432,7 +294,7 @@ extends Broadcast[T] with Logging with Serializable { // FIXME: Must fix this. This might never break if broadcast fails. // We should be able to break and send false. Also need to kill threads while (hasBlocks.get < totalBlocks) { - Thread.sleep(Broadcast.MaxKnockInterval) + Thread.sleep(MultiTracker.MaxKnockInterval) } return true @@ -446,36 +308,36 @@ extends Broadcast[T] with Logging with Serializable { private var blocksInRequestBitVector = new BitSet(totalBlocks) override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxRxSlots) + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = - math.min(listOfSources.size, Broadcast.MaxRxSlots) - + var numThreadsToCreate = 0 + listOfSources.synchronized { + numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - threadPool.getActiveCount + } while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { var peerToTalkTo = pickPeerToTalkToRandom if (peerToTalkTo != null) - logInfo("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) + logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) else - logInfo("No peer chosen...") + logDebug("No peer chosen...") if (peerToTalkTo != null) { threadPool.execute(new TalkToPeer(peerToTalkTo)) // Add to peersNowTalking. Remove in the thread. We have to do this // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { - peersNowTalking += peerToTalkTo - } + peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } } numThreadsToCreate = numThreadsToCreate - 1 } // Sleep for a while before starting some more threads - Thread.sleep(Broadcast.MinKnockInterval) + Thread.sleep(MultiTracker.MinKnockInterval) } // Shutdown the thread pool threadPool.shutdown() @@ -487,7 +349,7 @@ extends Broadcast[T] with Logging with Serializable { var curPeer: SourceInfo = null var curMax = 0 - logInfo("Picking peers to talk to...") + logDebug("Picking peers to talk to...") // Find peers that are not connected right now var peersNotInUse = ListBuffer[SourceInfo]() @@ -512,11 +374,10 @@ extends Broadcast[T] with Logging with Serializable { } } - // TODO: Always pick randomly or randomly pick randomly? - // Now always picking randomly + // Always picking randomly if (curPeer == null && peersNotInUse.size > 0) { // Pick uniformly the i'th required peer - var i = BitTorrentBroadcast.ranGen.nextInt(peersNotInUse.size) + var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) var peerIter = peersNotInUse.iterator curPeer = peerIter.next @@ -552,8 +413,8 @@ extends Broadcast[T] with Logging with Serializable { } } - // TODO: A block is rare if there are at most 2 copies of that block - // TODO: This CONSTANT could be a function of the neighborhood size + // A block is considered rare if there are at most 2 copies of that block + // This CONSTANT could be a function of the neighborhood size var rareBlocksIndices = ListBuffer[Int]() for (i <- 0 until totalBlocks) { if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { @@ -587,7 +448,7 @@ extends Broadcast[T] with Logging with Serializable { // Sort the peers based on how many rare blocks they have peersWithRareBlocks.sortBy(_._2) - var randomNumber = BitTorrentBroadcast.ranGen.nextDouble + var randomNumber = MultiTracker.ranGen.nextDouble var tempSum = 0.0 var i = 0 @@ -625,7 +486,7 @@ extends Broadcast[T] with Logging with Serializable { } var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, Broadcast.MaxKnockInterval) + timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) logInfo("TalkToPeer started... => " + peerToTalkTo) @@ -677,7 +538,7 @@ extends Broadcast[T] with Logging with Serializable { val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] val receptionTime = (System.currentTimeMillis - recvStartTime) - logInfo("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") + logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") if (!hasBlocksBitVector.get(bcBlock.blockID)) { arrayOfBlocks(bcBlock.blockID) = bcBlock @@ -688,8 +549,6 @@ extends Broadcast[T] with Logging with Serializable { hasBlocks.getAndIncrement } - rxSpeeds.addDataPoint(peerToTalkTo, receptionTime) - // Some block(may NOT be blockToAskFor) has arrived. // In any case, blockToAskFor is not in request any more blocksInRequestBitVector.synchronized { @@ -710,7 +569,7 @@ extends Broadcast[T] with Logging with Serializable { // connection due to timeout case eofe: java.io.EOFException => { } case e: Exception => { - logInfo("TalktoPeer had a " + e) + logError("TalktoPeer had a " + e) // FIXME: Remove 'newPeerToTalkTo' from listOfSources // We probably should have the following in some form, but not // really here. This exception can happen if the sender just breaks connection @@ -741,8 +600,8 @@ extends Broadcast[T] with Logging with Serializable { } // Include blocks already in transmission ONLY IF - // BitTorrentBroadcast.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { blocksInRequestBitVector.synchronized { needBlocksBitVector.or(blocksInRequestBitVector) } @@ -758,7 +617,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th required block - var i = BitTorrentBroadcast.ranGen.nextInt(needBlocksBitVector.cardinality) + var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) while (i > 0) { @@ -781,8 +640,8 @@ extends Broadcast[T] with Logging with Serializable { } // Include blocks already in transmission ONLY IF - // BitTorrentBroadcast.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < Broadcast.EndGameFraction) { + // MultiTracker.EndGameFraction has NOT been achieved + if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { blocksInRequestBitVector.synchronized { needBlocksBitVector.or(blocksInRequestBitVector) } @@ -830,7 +689,7 @@ extends Broadcast[T] with Logging with Serializable { return -1 } else { // Pick uniformly the i'th index - var i = BitTorrentBroadcast.ranGen.nextInt(minBlocksIndices.size) + var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) return minBlocksIndices(i) } } @@ -848,9 +707,7 @@ extends Broadcast[T] with Logging with Serializable { } // Delete from peersNowTalking - peersNowTalking.synchronized { - peersNowTalking = peersNowTalking - peerToTalkTo - } + peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } } } } @@ -868,32 +725,32 @@ extends Broadcast[T] with Logging with Serializable { guidePort = serverSocket.getLocalPort logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - guidePortLock.synchronized { - guidePortLock.notifyAll() - } + guidePortLock.synchronized { guidePortLock.notifyAll() } try { - // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { + while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept() } catch { case e: Exception => { - logInfo("GuideMultipleRequests Timeout.") - // Stop broadcast if at least one worker has connected and // everyone connected so far are done. Comparing with // listOfSources.size - 1, because it includes the Guide itself - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true + listOfSources.synchronized { + setOfCompletedSources.synchronized { + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") + } + } } } } if (clientSocket != null) { - logInfo("Guide: Accepted new client connection:" + clientSocket) + logDebug("Guide: Accepted new client connection:" + clientSocket) try { threadPool.execute(new GuideSingleRequest(clientSocket)) } catch { @@ -911,7 +768,7 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - unregisterBroadcast(uuid) + MultiTracker.unregisterBroadcast(id) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") @@ -930,13 +787,10 @@ extends Broadcast[T] with Logging with Serializable { try { // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) // Throw away whatever comes in gisSource.readObject.asInstanceOf[SourceInfo] @@ -946,7 +800,7 @@ extends Broadcast[T] with Logging with Serializable { gosSource.flush() } catch { case e: Exception => { - logInfo("sendStopBroadcastNotifications had a " + e) + logError("sendStopBroadcastNotifications had a " + e) } } finally { if (gisSource != null) { @@ -980,7 +834,7 @@ extends Broadcast[T] with Logging with Serializable { // Select a suitable source and send it back to the worker selectedSources = selectSuitableSources(sourceInfo) - logInfo("Sending selectedSources:" + selectedSources) + logDebug("Sending selectedSources:" + selectedSources) oos.writeObject(selectedSources) oos.flush() @@ -990,12 +844,11 @@ extends Broadcast[T] with Logging with Serializable { case e: Exception => { // Assuming exception caused by receiver failure: remove if (listOfSources != null) { - listOfSources.synchronized { - listOfSources = listOfSources - sourceInfo - } + listOfSources.synchronized { listOfSources -= sourceInfo } } } } finally { + logInfo("GuideSingleRequest is closing streams and sockets") ois.close() oos.close() clientSocket.close() @@ -1009,24 +862,22 @@ extends Broadcast[T] with Logging with Serializable { // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' // then add skipSourceInfo to setOfCompletedSources. Return blank. if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { - setOfCompletedSources += skipSourceInfo - } + setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } return selectedSources } listOfSources.synchronized { - if (listOfSources.size <= Broadcast.MaxPeersInGuideResponse) { + if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { selectedSources = listOfSources.clone } else { - var picksLeft = Broadcast.MaxPeersInGuideResponse + var picksLeft = MultiTracker.MaxPeersInGuideResponse var alreadyPicked = new BitSet(listOfSources.size) while (picksLeft > 0) { var i = -1 do { - i = BitTorrentBroadcast.ranGen.nextInt(listOfSources.size) + i = MultiTracker.ranGen.nextInt(listOfSources.size) } while (alreadyPicked.get(i)) var peerIter = listOfSources.iterator @@ -1057,8 +908,8 @@ extends Broadcast[T] with Logging with Serializable { class ServeMultipleRequests extends Thread with Logging { - // Server at most Broadcast.MaxTxSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(Broadcast.MaxTxSlots) + // Server at most MultiTracker.MaxChatSlots peers + var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) override def run() { var serverSocket = new ServerSocket(0) @@ -1066,30 +917,24 @@ extends Broadcast[T] with Logging with Serializable { logInfo("ServeMultipleRequests started with " + serverSocket) - listenPortLock.synchronized { - listenPortLock.notifyAll() - } + listenPortLock.synchronized { listenPortLock.notifyAll() } try { while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept() } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout.") - } + case e: Exception => { } } if (clientSocket != null) { - logInfo("Serve: Accepted new client connection:" + clientSocket) + logDebug("Serve: Accepted new client connection:" + clientSocket) try { threadPool.execute(new ServeSingleRequest(clientSocket)) } catch { // In failure, close socket here; else, the thread will close it - case ioe: IOException => { - clientSocket.close() - } + case ioe: IOException => clientSocket.close() } } } @@ -1125,14 +970,13 @@ extends Broadcast[T] with Logging with Serializable { if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { stopBroadcast = true } else { - // Carry on addToListOfSources(rxSourceInfo) } val startTime = System.currentTimeMillis var curTime = startTime var keepSending = true - var numBlocksToSend = Broadcast.MaxChatBlocks + var numBlocksToSend = MultiTracker.MaxChatBlocks while (!stopBroadcast && keepSending && numBlocksToSend > 0) { // Receive which block to send @@ -1140,7 +984,7 @@ extends Broadcast[T] with Logging with Serializable { // If it is master AND at least one copy of each block has not been // sent out already, MODIFY blockToSend - if (BitTorrentBroadcast.isMaster && sentBlocks.get < totalBlocks) { + if (MultiTracker.isMaster && sentBlocks.get < totalBlocks) { blockToSend = sentBlocks.getAndIncrement } @@ -1152,27 +996,21 @@ extends Broadcast[T] with Logging with Serializable { // Receive latest SourceInfo from the receiver rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - // logInfo("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) + logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) addToListOfSources(rxSourceInfo) curTime = System.currentTimeMillis // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= Broadcast.MaxChatTime && + if (curTime - startTime >= MultiTracker.MaxChatTime && threadPool.getQueue.size > 0) { keepSending = false } } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - // Exception can happen if the receiver stops receiving - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } + case e: Exception => logError("ServeSingleRequest had a " + e) } finally { logInfo("ServeSingleRequest is closing streams and sockets") ois.close() - // TODO: The following line causes a "java.net.SocketException: Socket closed" oos.close() clientSocket.close() } @@ -1183,173 +1021,20 @@ extends Broadcast[T] with Logging with Serializable { oos.writeObject(arrayOfBlocks(blockToSend)) oos.flush() } catch { - case e: Exception => { - logInfo("sendBlock had a " + e) - } + case e: Exception => logError("sendBlock had a " + e) } - logInfo("Sent block: " + blockToSend + " to " + clientSocket) + logDebug("Sent block: " + blockToSend + " to " + clientSocket) } } } } -class BitTorrentBroadcastFactory +private[spark] class BitTorrentBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { - BitTorrentBroadcast.initialize(isMaster) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean) = { - new BitTorrentBroadcast[T](value_, isLocal) - } -} - -private object BitTorrentBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuideMap = Map[UUID, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false + def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } - private var trackMV: TrackMultipleValues = null + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new BitTorrentBroadcast[T](value_, isLocal, id) - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - // TODO: Think about persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == Broadcast.REGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (uuid -> gInfo) - } - - logInfo ("New broadcast registered with TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.UNREGISTER_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToHDFS) - logInfo("Value unregistered from the Tracker " + valueToGuideMap) - } - - logInfo ("Broadcast unregistered from TrackMultipleValues " + uuid + " " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == Broadcast.FIND_BROADCAST_TRACKER) { - // Receive UUID - val uuid = ois.readObject.asInstanceOf[UUID] - - var gInfo = - if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else if (messageType == Broadcast.GET_UPDATED_SHARE) { - // TODO: Not implemented - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => { - clientSocket.close() - } - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } + def stop() { MultiTracker.stop() } } diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala index eaa9153279..6055bfd045 100644 --- a/core/src/main/scala/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/spark/broadcast/Broadcast.scala @@ -1,36 +1,29 @@ package spark.broadcast import java.io._ -import java.net._ -import java.util.{BitSet, UUID} -import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor} +import java.util.concurrent.atomic.AtomicLong import spark._ -trait Broadcast[T] extends Serializable { - val uuid = UUID.randomUUID - +abstract class Broadcast[T](id: Long) extends Serializable { def value: T // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. Possibly a Scala bug! + // readObject having to be 'private' in sub-classes. - override def toString = "spark.Broadcast(" + uuid + ")" + override def toString = "spark.Broadcast(" + id + ")" } -object Broadcast extends Logging with Serializable { - // Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - val GET_UPDATED_SHARE = 3 +private[spark] +class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable { private var initialized = false - private var isMaster_ = false private var broadcastFactory: BroadcastFactory = null + initialize() + // Called by SparkContext or Executor before using Broadcast - def initialize (isMaster__ : Boolean) { + private def initialize() { synchronized { if (!initialized) { val broadcastFactoryClass = System.getProperty( @@ -39,14 +32,6 @@ object Broadcast extends Logging with Serializable { broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - // Setup isMaster before using it - isMaster_ = isMaster__ - - // Set masterHostAddress to the master's IP address for the slaves to read - if (isMaster) { - System.setProperty("spark.broadcast.masterHostAddress", Utils.localIpAddress) - } - // Initialize appropriate BroadcastFactory and BroadcastObject broadcastFactory.initialize(isMaster) @@ -55,170 +40,14 @@ object Broadcast extends Logging with Serializable { } } - def getBroadcastFactory: BroadcastFactory = { - if (broadcastFactory == null) { - throw new SparkException ("Broadcast.getBroadcastFactory called before initialize") - } - broadcastFactory + def stop() { + broadcastFactory.stop() } - // Load common broadcast-related config parameters - private var MasterHostAddress_ = System.getProperty( - "spark.broadcast.masterHostAddress", "") - private var MasterTrackerPort_ = System.getProperty( - "spark.broadcast.masterTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load ChainedBroadcast config params - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty("spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxRxSlots_ = System.getProperty("spark.broadcast.maxRxSlots", "4").toInt - private var MaxTxSlots_ = System.getProperty("spark.broadcast.maxTxSlots", "4").toInt - - private var MaxChatTime_ = System.getProperty("spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty("spark.broadcast.maxChatBlocks", "1024").toInt + private val nextBroadcastId = new AtomicLong(0) - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble + def newBroadcast[T](value_ : T, isLocal: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) def isMaster = isMaster_ - - // Common config params - def MasterHostAddress = MasterHostAddress_ - def MasterTrackerPort = MasterTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // ChainedBroadcast configs - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxRxSlots = MaxRxSlots_ - def MaxTxSlots = MaxTxSlots_ - - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - // Helper functions to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / Broadcast.BlockSize) - if (byteArray.length % Broadcast.BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, Broadcast.BlockSize)) { - val thisBlockSize = math.min(Broadcast.BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper function to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * Broadcast.BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -case class BroadcastBlock (blockID: Int, byteArray: Array[Byte]) extends Serializable - -case class VariableInfo (@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) - extends Serializable { - - @transient - var hasBlocks = 0 -} - -class SpeedTracker extends Serializable { - // Mapping 'source' to '(totalTime, numBlocks)' - private var sourceToSpeedMap = Map[SourceInfo, (Long, Int)] () - - def addDataPoint (srcInfo: SourceInfo, timeInMillis: Long) { - sourceToSpeedMap.synchronized { - if (!sourceToSpeedMap.contains(srcInfo)) { - sourceToSpeedMap += (srcInfo -> (timeInMillis, 1)) - } else { - val tTnB = sourceToSpeedMap (srcInfo) - sourceToSpeedMap += (srcInfo -> (tTnB._1 + timeInMillis, tTnB._2 + 1)) - } - } - } - - def getTimePerBlock (srcInfo: SourceInfo): Double = { - sourceToSpeedMap.synchronized { - val tTnB = sourceToSpeedMap (srcInfo) - return tTnB._1 / tTnB._2 - } - } - - override def toString = sourceToSpeedMap.toString } diff --git a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala index b18908f789..ab6d302827 100644 --- a/core/src/main/scala/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/spark/broadcast/BroadcastFactory.scala @@ -6,7 +6,8 @@ package spark.broadcast * BroadcastFactory implementation to instantiate a particular broadcast for the * entire Spark job. */ -trait BroadcastFactory { +private[spark] trait BroadcastFactory { def initialize(isMaster: Boolean): Unit - def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T] + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] + def stop(): Unit } diff --git a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala b/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala deleted file mode 100644 index 43290c241f..0000000000 --- a/core/src/main/scala/spark/broadcast/ChainedBroadcast.scala +++ /dev/null @@ -1,794 +0,0 @@ -package spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, PriorityQueue, Random, UUID} - -import scala.collection.mutable.{Map, Set} -import scala.math - -import spark._ - -class ChainedBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { - - def value = value_ - - ChainedBroadcast.synchronized { - ChainedBroadcast.values.put(uuid, 0, value_) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var pqOfSources = new PriorityQueue[SourceInfo] - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var hasCopyInHDFS = false - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Store a persistent copy in HDFS - // TODO: Turned OFF for now - // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // TODO: Fix this at some point - hasCopyInHDFS = true - - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } - } - - pqOfSources = new PriorityQueue[SourceInfo] - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) - pqOfSources.add(masterSource) - - // Register with the Tracker - while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } - } - ChainedBroadcast.registerValue(uuid, guidePort) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - ChainedBroadcast.synchronized { - val cachedVal = ChainedBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Initializing everything because Master will only send null/0 values - initializeSlaveVariables - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy - if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - ChainedBroadcast.values.put(uuid, 0, value_) - } else { - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - ChainedBroadcast.values.put(uuid, 0, value_) - fileIn.close() - } - - val time =(System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } - - private def initializeSlaveVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - blockSize = -1 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def getMasterListenPort(variableUUID: UUID): Int = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var masterListenPort: Int = SourceInfo.TxOverGoToHDFS - - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out the guide - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send UUID and receive masterListenPort - oosTracker.writeObject(uuid) - oosTracker.flush() - masterListenPort = oisTracker.readObject.asInstanceOf[Int] - } catch { - case e: Exception => { - logInfo("getMasterListenPort had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - retriesLeft -= 1 - - Thread.sleep(ChainedBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + masterListenPort) - return masterListenPort - } - - def receiveBroadcast(variableUUID: UUID): Boolean = { - val masterListenPort = getMasterListenPort(variableUUID) - - if (masterListenPort == SourceInfo.TxOverGoToHDFS || - masterListenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } - } - - var clientSocketToMaster: Socket = null - var oosMaster: ObjectOutputStream = null - var oisMaster: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = Broadcast.MaxRetryCount - do { - // Connect to Master and send this worker's Information - clientSocketToMaster = - new Socket(Broadcast.MasterHostAddress, masterListenPort) - // TODO: Guiding object connection is reusable - oosMaster = - new ObjectOutputStream(clientSocketToMaster.getOutputStream) - oosMaster.flush() - oisMaster = - new ObjectInputStream(clientSocketToMaster.getInputStream) - - logInfo("Connected to Master's guiding object") - - // Send local source information - oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) - oosMaster.flush() - - // Receive source information from Master - var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } - totalBytes = sourceInfo.totalBytes - - logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time =(System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Master will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Master - oosMaster.writeObject(sourceInfo) - - if (oisMaster != null) { - oisMaster.close() - } - if (oosMaster != null) { - oosMaster.close() - } - if (clientSocketToMaster != null) { - clientSocketToMaster.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return(hasBlocks == totalBlocks) - } - - // Tries to receive broadcast from the source and returns Boolean status. - // This might be called multiple times to retry a defined number of times. - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(clientSocketToSource.getInputStream) - - logInfo("Inside receiveSingleTransmission") - logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime =(System.currentTimeMillis - recvStartTime) - - logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll() - } - } - } catch { - case e: Exception => { - logInfo("receiveSingleTransmission had a " + e) - } - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { - guidePortLock.notifyAll() - } - - try { - // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("GuideMultipleRequests Timeout.") - - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // pqOfSources.size - 1, because it includes the Guide itself - if (pqOfSources.size > 1 && - setOfCompletedSources.size == pqOfSources.size - 1) { - stopBroadcast = true - } - } - } - if (clientSocket != null) { - logInfo("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - ChainedBroadcast.unregisterValue(uuid) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - pqOfSources.synchronized { - var pqIter = pqOfSources.iterator - while (pqIter.hasNext) { - var sourceInfo = pqIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 - gosSource.writeObject((SourceInfo.StopBroadcast, - SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logInfo("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid(SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logInfo("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new(if it can finish) source to the PQ of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) - logInfo("Adding possible new source to pqOfSources: " + thisWorkerInfo) - pqOfSources.add(thisWorkerInfo) - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in pqOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - pqOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(pqOfSources.contains(selectedSourceInfo)) - - // Remove first - pqOfSources.remove(selectedSourceInfo) - // TODO: Removing a source based on just one failure notification! - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - selectedSourceInfo.currentLeechers -= 1 - - // Put it back - pqOfSources.add(selectedSourceInfo) - } - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - // Assuming that exception caused due to receiver worker failure. - // Remove failed worker from pqOfSources and update leecherCount of - // corresponding source worker - pqOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - pqOfSources.remove(selectedSourceInfo) - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - pqOfSources.add(selectedSourceInfo) - } - - // Remove thisWorkerInfo - if (pqOfSources != null) { - pqOfSources.remove(thisWorkerInfo) - } - } - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - - // FIXME: Caller must have a synchronized block on pqOfSources - // FIXME: If a worker fails to get the broadcasted variable from a source and - // comes back to Master, this function might choose the worker itself as a - // source tp create a dependency cycle(this worker was put into pqOfSources - // as a streming source when it first arrived). The length of this cycle can - // be arbitrarily long. - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one based on the ordering strategy(e.g., least leechers etc.) - // take is a blocking call removing the element from PQ - var selectedSource = pqOfSources.poll - assert(selectedSource != null) - // Update leecher count - selectedSource.currentLeechers += 1 - // Add it back and then return - pqOfSources.add(selectedSource) - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { - listenPortLock.notifyAll() - } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout.") - } - } - if (clientSocket != null) { - logInfo("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - if (sendFrom == SourceInfo.StopBroadcast && - sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - // Carry on - sendObject - } - } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Master - while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait() - } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => { - logInfo("sendObject had a " + e) - } - } - logInfo("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -class ChainedBroadcastFactory -extends BroadcastFactory { - def initialize(isMaster: Boolean) { - ChainedBroadcast.initialize(isMaster) - } - def newBroadcast[T](value_ : T, isLocal: Boolean) = { - new ChainedBroadcast[T](value_, isLocal) - } -} - -private object ChainedBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - var valueToGuidePortMap = Map[UUID, Int]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - def registerValue(uuid: UUID, guidePort: Int) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap +=(uuid -> guidePort) - logInfo("New value registered with the Tracker " + valueToGuidePortMap) - } - } - - def unregisterValue(uuid: UUID) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS - logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) - } - } - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - try { - val uuid = ois.readObject.asInstanceOf[UUID] - var guidePort = - if (valueToGuidePortMap.contains(uuid)) { - valueToGuidePortMap(uuid) - } else SourceInfo.TxNotStartedRetry - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) - oos.writeObject(guidePort) - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - - // Shutdown the thread pool - threadPool.shutdown() - } - } -} diff --git a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala b/core/src/main/scala/spark/broadcast/DfsBroadcast.scala deleted file mode 100644 index d18dfb8963..0000000000 --- a/core/src/main/scala/spark/broadcast/DfsBroadcast.scala +++ /dev/null @@ -1,135 +0,0 @@ -package spark.broadcast - -import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} - -import java.io._ -import java.net._ -import java.util.UUID - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem} - -import spark._ - -class DfsBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { - - def value = value_ - - DfsBroadcast.synchronized { - DfsBroadcast.values.put(uuid, 0, value_) - } - - if (!isLocal) { - sendBroadcast - } - - def sendBroadcast () { - val out = new ObjectOutputStream (DfsBroadcast.openFileForWriting(uuid)) - out.writeObject (value_) - out.close() - } - - // Called by JVM when deserializing an object - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - DfsBroadcast.synchronized { - val cachedVal = DfsBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - logInfo( "Started reading Broadcasted variable " + uuid) - val start = System.nanoTime - - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - DfsBroadcast.values.put(uuid, 0, value_) - fileIn.close() - - val time = (System.nanoTime - start) / 1e9 - logInfo( "Reading Broadcasted variable " + uuid + " took " + time + " s") - } - } - } -} - -class DfsBroadcastFactory -extends BroadcastFactory { - def initialize (isMaster: Boolean) { - DfsBroadcast.initialize - } - def newBroadcast[T] (value_ : T, isLocal: Boolean) = - new DfsBroadcast[T] (value_, isLocal) -} - -private object DfsBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - - private var initialized = false - - private var fileSystem: FileSystem = null - private var workDir: String = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - - def initialize() { - synchronized { - if (!initialized) { - bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val dfs = System.getProperty("spark.dfs", "file:///") - if (!dfs.startsWith("file://")) { - val conf = new Configuration() - conf.setInt("io.file.buffer.size", bufferSize) - val rep = System.getProperty("spark.dfs.replication", "3").toInt - conf.setInt("dfs.replication", rep) - fileSystem = FileSystem.get(new URI(dfs), conf) - } - workDir = System.getProperty("spark.dfs.workDir", "/tmp") - compress = System.getProperty("spark.compress", "false").toBoolean - - initialized = true - } - } - } - - private def getPath(uuid: UUID) = new Path(workDir + "/broadcast-" + uuid) - - def openFileForReading(uuid: UUID): InputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.open(getPath(uuid)) - } else { - // Local filesystem - new FileInputStream(getPath(uuid).toString) - } - - if (compress) { - // LZF stream does its own buffering - new LZFInputStream(fileStream) - } else if (fileSystem == null) { - new BufferedInputStream(fileStream, bufferSize) - } else { - // Hadoop streams do their own buffering - fileStream - } - } - - def openFileForWriting(uuid: UUID): OutputStream = { - val fileStream = if (fileSystem != null) { - fileSystem.create(getPath(uuid)) - } else { - // Local filesystem - new FileOutputStream(getPath(uuid).toString) - } - - if (compress) { - // LZF stream does its own buffering - new LZFOutputStream(fileStream) - } else if (fileSystem == null) { - new BufferedOutputStream(fileStream, bufferSize) - } else { - // Hadoop streams do their own buffering - fileStream - } - } -} diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala index 6e3dde76bd..7eb4ddb74f 100644 --- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala @@ -10,49 +10,52 @@ import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import spark._ +import spark.storage.StorageLevel -class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { def value = value_ - HttpBroadcast.synchronized { - HttpBroadcast.values.put(uuid, 0, value_) + def blockId: String = "broadcast_" + id + + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) } if (!isLocal) { - HttpBroadcast.write(uuid, value_) + HttpBroadcast.write(id, value_) } // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { - val cachedVal = HttpBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - logInfo("Started reading broadcast variable " + uuid) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](uuid) - HttpBroadcast.values.put(uuid, 0, value_) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + uuid + " took " + time + " s") + SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => value_ = x.asInstanceOf[T] + case None => { + logInfo("Started reading broadcast variable " + id) + val start = System.nanoTime + value_ = HttpBroadcast.read[T](id) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } } } } } -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) { - HttpBroadcast.initialize(isMaster) - } - def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal) +private[spark] class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } } private object HttpBroadcast extends Logging { - val values = SparkEnv.get.cache.newKeySpace() - private var initialized = false private var broadcastDir: File = null @@ -65,7 +68,7 @@ private object HttpBroadcast extends Logging { synchronized { if (!initialized) { bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - compress = System.getProperty("spark.compress", "false").toBoolean + compress = System.getProperty("spark.broadcast.compress", "true").toBoolean if (isMaster) { createServer() } @@ -74,6 +77,16 @@ private object HttpBroadcast extends Logging { } } } + + def stop() { + synchronized { + if (server != null) { + server.stop() + server = null + } + initialized = false + } + } private def createServer() { broadcastDir = Utils.createTempDir() @@ -84,8 +97,8 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def write(uuid: UUID, value: Any) { - val file = new File(broadcastDir, "broadcast-" + uuid) + def write(id: Long, value: Any) { + val file = new File(broadcastDir, "broadcast-" + id) val out: OutputStream = if (compress) { new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering } else { @@ -97,8 +110,8 @@ private object HttpBroadcast extends Logging { serOut.close() } - def read[T](uuid: UUID): T = { - val url = serverUri + "/broadcast-" + uuid + def read[T](id: Long): T = { + val url = serverUri + "/broadcast-" + id var in = if (compress) { new LZFInputStream(new URL(url).openStream()) // Does its own buffering } else { diff --git a/core/src/main/scala/spark/broadcast/MultiTracker.scala b/core/src/main/scala/spark/broadcast/MultiTracker.scala new file mode 100644 index 0000000000..5e76dedb94 --- /dev/null +++ b/core/src/main/scala/spark/broadcast/MultiTracker.scala @@ -0,0 +1,393 @@ +package spark.broadcast + +import java.io._ +import java.net._ +import java.util.Random + +import scala.collection.mutable.Map + +import spark._ + +private object MultiTracker +extends Logging { + + // Tracker Messages + val REGISTER_BROADCAST_TRACKER = 0 + val UNREGISTER_BROADCAST_TRACKER = 1 + val FIND_BROADCAST_TRACKER = 2 + + // Map to keep track of guides of ongoing broadcasts + var valueToGuideMap = Map[Long, SourceInfo]() + + // Random number generator + var ranGen = new Random + + private var initialized = false + private var isMaster_ = false + + private var stopBroadcast = false + + private var trackMV: TrackMultipleValues = null + + def initialize(isMaster__ : Boolean) { + synchronized { + if (!initialized) { + + isMaster_ = isMaster__ + + if (isMaster) { + trackMV = new TrackMultipleValues + trackMV.setDaemon(true) + trackMV.start() + + // Set masterHostAddress to the master's IP address for the slaves to read + System.setProperty("spark.MultiTracker.MasterHostAddress", Utils.localIpAddress) + } + + initialized = true + } + } + } + + def stop() { + stopBroadcast = true + } + + // Load common parameters + private var MasterHostAddress_ = System.getProperty( + "spark.MultiTracker.MasterHostAddress", "") + private var MasterTrackerPort_ = System.getProperty( + "spark.broadcast.masterTrackerPort", "11111").toInt + private var BlockSize_ = System.getProperty( + "spark.broadcast.blockSize", "4096").toInt * 1024 + private var MaxRetryCount_ = System.getProperty( + "spark.broadcast.maxRetryCount", "2").toInt + + private var TrackerSocketTimeout_ = System.getProperty( + "spark.broadcast.trackerSocketTimeout", "50000").toInt + private var ServerSocketTimeout_ = System.getProperty( + "spark.broadcast.serverSocketTimeout", "10000").toInt + + private var MinKnockInterval_ = System.getProperty( + "spark.broadcast.minKnockInterval", "500").toInt + private var MaxKnockInterval_ = System.getProperty( + "spark.broadcast.maxKnockInterval", "999").toInt + + // Load TreeBroadcast config params + private var MaxDegree_ = System.getProperty( + "spark.broadcast.maxDegree", "2").toInt + + // Load BitTorrentBroadcast config params + private var MaxPeersInGuideResponse_ = System.getProperty( + "spark.broadcast.maxPeersInGuideResponse", "4").toInt + + private var MaxChatSlots_ = System.getProperty( + "spark.broadcast.maxChatSlots", "4").toInt + private var MaxChatTime_ = System.getProperty( + "spark.broadcast.maxChatTime", "500").toInt + private var MaxChatBlocks_ = System.getProperty( + "spark.broadcast.maxChatBlocks", "1024").toInt + + private var EndGameFraction_ = System.getProperty( + "spark.broadcast.endGameFraction", "0.95").toDouble + + def isMaster = isMaster_ + + // Common config params + def MasterHostAddress = MasterHostAddress_ + def MasterTrackerPort = MasterTrackerPort_ + def BlockSize = BlockSize_ + def MaxRetryCount = MaxRetryCount_ + + def TrackerSocketTimeout = TrackerSocketTimeout_ + def ServerSocketTimeout = ServerSocketTimeout_ + + def MinKnockInterval = MinKnockInterval_ + def MaxKnockInterval = MaxKnockInterval_ + + // TreeBroadcast configs + def MaxDegree = MaxDegree_ + + // BitTorrentBroadcast configs + def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ + + def MaxChatSlots = MaxChatSlots_ + def MaxChatTime = MaxChatTime_ + def MaxChatBlocks = MaxChatBlocks_ + + def EndGameFraction = EndGameFraction_ + + class TrackMultipleValues + extends Thread with Logging { + override def run() { + var threadPool = Utils.newDaemonCachedThreadPool() + var serverSocket: ServerSocket = null + + serverSocket = new ServerSocket(MasterTrackerPort) + logInfo("TrackMultipleValues started at " + serverSocket) + + try { + while (!stopBroadcast) { + var clientSocket: Socket = null + try { + serverSocket.setSoTimeout(TrackerSocketTimeout) + clientSocket = serverSocket.accept() + } catch { + case e: Exception => { + if (stopBroadcast) { + logInfo("Stopping TrackMultipleValues...") + } + } + } + + if (clientSocket != null) { + try { + threadPool.execute(new Thread { + override def run() { + val oos = new ObjectOutputStream(clientSocket.getOutputStream) + oos.flush() + val ois = new ObjectInputStream(clientSocket.getInputStream) + + try { + // First, read message type + val messageType = ois.readObject.asInstanceOf[Int] + + if (messageType == REGISTER_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + // Receive hostAddress and listenPort + val gInfo = ois.readObject.asInstanceOf[SourceInfo] + + // Add to the map + valueToGuideMap.synchronized { + valueToGuideMap += (id -> gInfo) + } + + logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + + // Remove from the map + valueToGuideMap.synchronized { + valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) + } + + logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) + + // Send dummy ACK + oos.writeObject(-1) + oos.flush() + } else if (messageType == FIND_BROADCAST_TRACKER) { + // Receive Long + val id = ois.readObject.asInstanceOf[Long] + + var gInfo = + if (valueToGuideMap.contains(id)) valueToGuideMap(id) + else SourceInfo("", SourceInfo.TxNotStartedRetry) + + logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) + + // Send reply back + oos.writeObject(gInfo) + oos.flush() + } else { + throw new SparkException("Undefined messageType at TrackMultipleValues") + } + } catch { + case e: Exception => { + logError("TrackMultipleValues had a " + e) + } + } finally { + ois.close() + oos.close() + clientSocket.close() + } + } + }) + } catch { + // In failure, close socket here; else, client thread will close + case ioe: IOException => clientSocket.close() + } + } + } + } finally { + serverSocket.close() + } + // Shutdown the thread pool + threadPool.shutdown() + } + } + + def getGuideInfo(variableLong: Long): SourceInfo = { + var clientSocketToTracker: Socket = null + var oosTracker: ObjectOutputStream = null + var oisTracker: ObjectInputStream = null + + var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) + + var retriesLeft = MultiTracker.MaxRetryCount + do { + try { + // Connect to the tracker to find out GuideInfo + clientSocketToTracker = + new Socket(MultiTracker.MasterHostAddress, MultiTracker.MasterTrackerPort) + oosTracker = + new ObjectOutputStream(clientSocketToTracker.getOutputStream) + oosTracker.flush() + oisTracker = + new ObjectInputStream(clientSocketToTracker.getInputStream) + + // Send messageType/intention + oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) + oosTracker.flush() + + // Send Long and receive GuideInfo + oosTracker.writeObject(variableLong) + oosTracker.flush() + gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] + } catch { + case e: Exception => logError("getGuideInfo had a " + e) + } finally { + if (oisTracker != null) { + oisTracker.close() + } + if (oosTracker != null) { + oosTracker.close() + } + if (clientSocketToTracker != null) { + clientSocketToTracker.close() + } + } + + Thread.sleep(MultiTracker.ranGen.nextInt( + MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + + MultiTracker.MinKnockInterval) + + retriesLeft -= 1 + } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) + + logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) + return gInfo + } + + def registerBroadcast(id: Long, gInfo: SourceInfo) { + val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(REGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send Long of this broadcast + oosST.writeObject(id) + oosST.flush() + + // Send this tracker's information + oosST.writeObject(gInfo) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + def unregisterBroadcast(id: Long) { + val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort) + val oosST = new ObjectOutputStream(socket.getOutputStream) + oosST.flush() + val oisST = new ObjectInputStream(socket.getInputStream) + + // Send messageType/intention + oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) + oosST.flush() + + // Send Long of this broadcast + oosST.writeObject(id) + oosST.flush() + + // Receive ACK and throw it away + oisST.readObject.asInstanceOf[Int] + + // Shut stuff down + oisST.close() + oosST.close() + socket.close() + } + + // Helper method to convert an object to Array[BroadcastBlock] + def blockifyObject[IN](obj: IN): VariableInfo = { + val baos = new ByteArrayOutputStream + val oos = new ObjectOutputStream(baos) + oos.writeObject(obj) + oos.close() + baos.close() + val byteArray = baos.toByteArray + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BlockSize) + if (byteArray.length % BlockSize != 0) + blockNum += 1 + + var retVal = new Array[BroadcastBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BlockSize)) { + val thisBlockSize = math.min(BlockSize, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) + variableInfo.hasBlocks = blockNum + + return variableInfo + } + + // Helper method to convert Array[BroadcastBlock] to object + def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], + totalBytes: Int, + totalBlocks: Int): OUT = { + + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BlockSize, arrayOfBlocks(i).byteArray.length) + } + byteArrayToObject(retByteArray) + } + + private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { + val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) + } + val retVal = in.readObject.asInstanceOf[OUT] + in.close() + return retVal + } +} + +private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) +extends Serializable + +private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], + totalBlocks: Int, + totalBytes: Int) +extends Serializable { + @transient var hasBlocks = 0 +} diff --git a/core/src/main/scala/spark/broadcast/SourceInfo.scala b/core/src/main/scala/spark/broadcast/SourceInfo.scala index 09907f4ee7..c79bb93c38 100644 --- a/core/src/main/scala/spark/broadcast/SourceInfo.scala +++ b/core/src/main/scala/spark/broadcast/SourceInfo.scala @@ -6,15 +6,11 @@ import spark._ /** * Used to keep and pass around information of peers involved in a broadcast - * - * CHANGED: Keep track of the blockSize for THIS broadcast variable. - * Broadcast.BlockSize is expected to be updated across different broadcasts */ -case class SourceInfo (hostAddress: String, +private[spark] case class SourceInfo (hostAddress: String, listenPort: Int, totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam, - blockSize: Int = Broadcast.BlockSize) + totalBytes: Int = SourceInfo.UnusedParam) extends Comparable[SourceInfo] with Logging { var currentLeechers = 0 @@ -30,11 +26,12 @@ extends Comparable[SourceInfo] with Logging { /** * Helper Object of SourceInfo for its constants */ -object SourceInfo { - // Constants for special values of listenPort +private[spark] object SourceInfo { + // Broadcast has not started yet! Should never happen. val TxNotStartedRetry = -1 - val TxOverGoToHDFS = 0 + // Broadcast has already finished. Try default mechanism. + val TxOverGoToDefault = -3 // Other constants val StopBroadcast = -2 val UnusedParam = 0 -}
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala index f5527b6ec9..fa676e9064 100644 --- a/core/src/main/scala/spark/broadcast/TreeBroadcast.scala +++ b/core/src/main/scala/spark/broadcast/TreeBroadcast.scala @@ -8,22 +8,23 @@ import scala.collection.mutable.{ListBuffer, Map, Set} import scala.math import spark._ +import spark.storage.StorageLevel -class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean) -extends Broadcast[T] with Logging with Serializable { +private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { def value = value_ - TreeBroadcast.synchronized { - TreeBroadcast.values.put(uuid, 0, value_) + def blockId = "broadcast_" + id + + MultiTracker.synchronized { + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) } @transient var arrayOfBlocks: Array[BroadcastBlock] = null @transient var totalBytes = -1 @transient var totalBlocks = -1 @transient var hasBlocks = 0 - // CHANGED: BlockSize in the Broadcast object is expected to change over time - @transient var blockSize = Broadcast.BlockSize @transient var listenPortLock = new Object @transient var guidePortLock = new Object @@ -35,34 +36,24 @@ extends Broadcast[T] with Logging with Serializable { @transient var serveMR: ServeMultipleRequests = null @transient var guideMR: GuideMultipleRequests = null - @transient var hostAddress = Utils.localIpAddress + @transient var hostAddress = Utils.localIpAddress() @transient var listenPort = -1 @transient var guidePort = -1 - @transient var hasCopyInHDFS = false @transient var stopBroadcast = false // Must call this after all the variables have been created/initialized if (!isLocal) { - sendBroadcast + sendBroadcast() } def sendBroadcast() { logInfo("Local host address: " + hostAddress) - // Store a persistent copy in HDFS - // TODO: Turned OFF for now - // val out = new ObjectOutputStream(DfsBroadcast.openFileForWriting(uuid)) - // out.writeObject(value_) - // out.close() - // TODO: Fix this at some point - hasCopyInHDFS = true - // Create a variableInfo object and store it in valueInfos - var variableInfo = Broadcast.blockifyObject(value_) + var variableInfo = MultiTracker.blockifyObject(value_) // Prepare the value being broadcasted - // TODO: Refactoring and clean-up required here arrayOfBlocks = variableInfo.arrayOfBlocks totalBytes = variableInfo.totalBytes totalBlocks = variableInfo.totalBlocks @@ -75,9 +66,7 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER guideMR is created while (guidePort == -1) { - guidePortLock.synchronized { - guidePortLock.wait() - } + guidePortLock.synchronized { guidePortLock.wait() } } serveMR = new ServeMultipleRequests @@ -87,63 +76,61 @@ extends Broadcast[T] with Logging with Serializable { // Must always come AFTER serveMR is created while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } // Must always come AFTER listenPort is created val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes, blockSize) + SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) listOfSources += masterSource // Register with the Tracker - TreeBroadcast.registerValue(uuid, guidePort) + MultiTracker.registerBroadcast(id, + SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) } private def readObject(in: ObjectInputStream) { in.defaultReadObject() - TreeBroadcast.synchronized { - val cachedVal = TreeBroadcast.values.get(uuid, 0) - if (cachedVal != null) { - value_ = cachedVal.asInstanceOf[T] - } else { - // Initializing everything because Master will only send null/0 values - initializeSlaveVariables - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(uuid) - // If does not succeed, then get from HDFS copy - if (receptionSucceeded) { - value_ = Broadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - TreeBroadcast.values.put(uuid, 0, value_) - } else { - val fileIn = new ObjectInputStream(DfsBroadcast.openFileForReading(uuid)) - value_ = fileIn.readObject.asInstanceOf[T] - TreeBroadcast.values.put(uuid, 0, value_) - fileIn.close() - } + MultiTracker.synchronized { + SparkEnv.get.blockManager.getSingle(blockId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + logInfo("Started reading broadcast variable " + id) + // Initializing everything because Master will only send null/0 values + // Only the 1st worker in a node can be here. Others will get from cache + initializeWorkerVariables() + + logInfo("Local host address: " + hostAddress) + + serveMR = new ServeMultipleRequests + serveMR.setDaemon(true) + serveMR.start() + logInfo("ServeMultipleRequests started...") + + val start = System.nanoTime + + val receptionSucceeded = receiveBroadcast(id) + if (receptionSucceeded) { + value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + } else { + logError("Reading broadcast variable " + id + " failed") + } - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s") + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") } } } - private def initializeSlaveVariables() { + private def initializeWorkerVariables() { arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 hasBlocks = 0 - blockSize = -1 listenPortLock = new Object totalBlocksLock = new Object @@ -151,78 +138,23 @@ extends Broadcast[T] with Logging with Serializable { serveMR = null - hostAddress = Utils.localIpAddress + hostAddress = Utils.localIpAddress() listenPort = -1 stopBroadcast = false } - def getMasterListenPort(variableUUID: UUID): Int = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var masterListenPort: Int = SourceInfo.TxOverGoToHDFS - - var retriesLeft = Broadcast.MaxRetryCount - do { - try { - // Connect to the tracker to find out the guide - clientSocketToTracker = - new Socket(Broadcast.MasterHostAddress, Broadcast.MasterTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send UUID and receive masterListenPort - oosTracker.writeObject(uuid) - oosTracker.flush() - masterListenPort = oisTracker.readObject.asInstanceOf[Int] - } catch { - case e: Exception => { - logInfo("getMasterListenPort had a " + e) - } - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - retriesLeft -= 1 - - Thread.sleep(TreeBroadcast.ranGen.nextInt( - Broadcast.MaxKnockInterval - Broadcast.MinKnockInterval) + - Broadcast.MinKnockInterval) - - } while (retriesLeft > 0 && masterListenPort == SourceInfo.TxNotStartedRetry) - - logInfo("Got this guidePort from Tracker: " + masterListenPort) - return masterListenPort - } - - def receiveBroadcast(variableUUID: UUID): Boolean = { - val masterListenPort = getMasterListenPort(variableUUID) - - if (masterListenPort == SourceInfo.TxOverGoToHDFS || - masterListenPort == SourceInfo.TxNotStartedRetry) { - // TODO: SourceInfo.TxNotStartedRetry is not really in use because we go - // to HDFS anyway when receiveBroadcast returns false + def receiveBroadcast(variableID: Long): Boolean = { + val gInfo = MultiTracker.getGuideInfo(variableID) + + if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { return false } // Wait until hostAddress and listenPort are created by the // ServeMultipleRequests thread while (listenPort == -1) { - listenPortLock.synchronized { - listenPortLock.wait() - } + listenPortLock.synchronized { listenPortLock.wait() } } var clientSocketToMaster: Socket = null @@ -231,19 +163,15 @@ extends Broadcast[T] with Logging with Serializable { // Connect and receive broadcast from the specified source, retrying the // specified number of times in case of failures - var retriesLeft = Broadcast.MaxRetryCount + var retriesLeft = MultiTracker.MaxRetryCount do { // Connect to Master and send this worker's Information - clientSocketToMaster = - new Socket(Broadcast.MasterHostAddress, masterListenPort) - // TODO: Guiding object connection is reusable - oosMaster = - new ObjectOutputStream(clientSocketToMaster.getOutputStream) + clientSocketToMaster = new Socket(MultiTracker.MasterHostAddress, gInfo.listenPort) + oosMaster = new ObjectOutputStream(clientSocketToMaster.getOutputStream) oosMaster.flush() - oisMaster = - new ObjectInputStream(clientSocketToMaster.getInputStream) + oisMaster = new ObjectInputStream(clientSocketToMaster.getInputStream) - logInfo("Connected to Master's guiding object") + logDebug("Connected to Master's guiding object") // Send local source information oosMaster.writeObject(SourceInfo(hostAddress, listenPort)) @@ -253,13 +181,10 @@ extends Broadcast[T] with Logging with Serializable { var sourceInfo = oisMaster.readObject.asInstanceOf[SourceInfo] totalBlocks = sourceInfo.totalBlocks arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { - totalBlocksLock.notifyAll() - } + totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } totalBytes = sourceInfo.totalBytes - blockSize = sourceInfo.blockSize - logInfo("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) + logDebug("Received SourceInfo from Master:" + sourceInfo + " My Port: " + listenPort) val start = System.nanoTime val receptionSucceeded = receiveSingleTransmission(sourceInfo) @@ -289,8 +214,10 @@ extends Broadcast[T] with Logging with Serializable { return (hasBlocks == totalBlocks) } - // Tries to receive broadcast from the source and returns Boolean status. - // This might be called multiple times to retry a defined number of times. + /** + * Tries to receive broadcast from the source and returns Boolean status. + * This might be called multiple times to retry a defined number of times. + */ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { var clientSocketToSource: Socket = null var oosSource: ObjectOutputStream = null @@ -299,16 +226,13 @@ extends Broadcast[T] with Logging with Serializable { var receptionSucceeded = false try { // Connect to the source to get the object itself - clientSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = - new ObjectOutputStream(clientSocketToSource.getOutputStream) + clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) oosSource.flush() - oisSource = - new ObjectInputStream(clientSocketToSource.getInputStream) + oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - logInfo("Inside receiveSingleTransmission") - logInfo("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) + logDebug("Inside receiveSingleTransmission") + logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) // Send the range oosSource.writeObject((hasBlocks, totalBlocks)) @@ -319,20 +243,17 @@ extends Broadcast[T] with Logging with Serializable { val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] val receptionTime = (System.currentTimeMillis - recvStartTime) - logInfo("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") + logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") arrayOfBlocks(hasBlocks) = bcBlock hasBlocks += 1 + // Set to true if at least one block is received receptionSucceeded = true - hasBlocksLock.synchronized { - hasBlocksLock.notifyAll() - } + hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } } } catch { - case e: Exception => { - logInfo("receiveSingleTransmission had a " + e) - } + case e: Exception => logError("receiveSingleTransmission had a " + e) } finally { if (oisSource != null) { oisSource.close() @@ -361,32 +282,32 @@ extends Broadcast[T] with Logging with Serializable { guidePort = serverSocket.getLocalPort logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - guidePortLock.synchronized { - guidePortLock.notifyAll() - } + guidePortLock.synchronized { guidePortLock.notifyAll() } try { - // Don't stop until there is a copy in HDFS - while (!stopBroadcast || !hasCopyInHDFS) { + while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { case e: Exception => { - logInfo("GuideMultipleRequests Timeout.") - // Stop broadcast if at least one worker has connected and // everyone connected so far are done. Comparing with // listOfSources.size - 1, because it includes the Guide itself - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true + listOfSources.synchronized { + setOfCompletedSources.synchronized { + if (listOfSources.size > 1 && + setOfCompletedSources.size == listOfSources.size - 1) { + stopBroadcast = true + logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") + } + } } } } if (clientSocket != null) { - logInfo("Guide: Accepted new client connection: " + clientSocket) + logDebug("Guide: Accepted new client connection: " + clientSocket) try { threadPool.execute(new GuideSingleRequest(clientSocket)) } catch { @@ -399,14 +320,13 @@ extends Broadcast[T] with Logging with Serializable { logInfo("Sending stopBroadcast notifications...") sendStopBroadcastNotifications - TreeBroadcast.unregisterValue(uuid) + MultiTracker.unregisterBroadcast(id) } finally { if (serverSocket != null) { logInfo("GuideMultipleRequests now stopping...") serverSocket.close() } } - // Shutdown the thread pool threadPool.shutdown() } @@ -423,21 +343,17 @@ extends Broadcast[T] with Logging with Serializable { try { // Connect to the source - guideSocketToSource = - new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = - new ObjectOutputStream(guideSocketToSource.getOutputStream) + guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) + gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) gosSource.flush() - gisSource = - new ObjectInputStream(guideSocketToSource.getInputStream) + gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - // Send stopBroadcast signal. Range = SourceInfo.StopBroadcast*2 - gosSource.writeObject((SourceInfo.StopBroadcast, - SourceInfo.StopBroadcast)) + // Send stopBroadcast signal + gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) gosSource.flush() } catch { case e: Exception => { - logInfo("sendStopBroadcastNotifications had a " + e) + logError("sendStopBroadcastNotifications had a " + e) } } finally { if (gisSource != null) { @@ -473,14 +389,14 @@ extends Broadcast[T] with Logging with Serializable { listOfSources.synchronized { // Select a suitable source and send it back to the worker selectedSourceInfo = selectSuitableSource(sourceInfo) - logInfo("Sending selectedSourceInfo: " + selectedSourceInfo) + logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) oos.writeObject(selectedSourceInfo) oos.flush() // Add this new (if it can finish) source to the list of sources thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes, blockSize) - logInfo("Adding possible new source to listOfSources: " + thisWorkerInfo) + sourceInfo.listenPort, totalBlocks, totalBytes) + logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) listOfSources += thisWorkerInfo } @@ -492,9 +408,9 @@ extends Broadcast[T] with Logging with Serializable { // This should work since SourceInfo is a case class assert(listOfSources.contains(selectedSourceInfo)) - // Remove first + // Remove first + // (Currently removing a source based on just one failure notification!) listOfSources = listOfSources - selectedSourceInfo - // TODO: Removing a source based on just one failure notification! // Update sourceInfo and put it back in, IF reception succeeded if (!sourceInfo.receptionFailed) { @@ -503,17 +419,13 @@ extends Broadcast[T] with Logging with Serializable { setOfCompletedSources += thisWorkerInfo } + // Update leecher count and put it back in selectedSourceInfo.currentLeechers -= 1 - - // Put it back listOfSources += selectedSourceInfo } } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close() everything up case e: Exception => { - // Assuming that exception caused due to receiver worker failure. // Remove failed worker from listOfSources and update leecherCount of // corresponding source worker listOfSources.synchronized { @@ -532,27 +444,23 @@ extends Broadcast[T] with Logging with Serializable { } } } finally { + logInfo("GuideSingleRequest is closing streams and sockets") ois.close() oos.close() clientSocket.close() } } - // FIXME: Caller must have a synchronized block on listOfSources - // FIXME: If a worker fails to get the broadcasted variable from a source - // and comes back to the Master, this function might choose the worker - // itself as a source to create a dependency cycle (this worker was put - // into listOfSources as a streming source when it first arrived). The - // length of this cycle can be arbitrarily long. + // Assuming the caller to have a synchronized block on listOfSources + // Select one with the most leechers. This will level-wise fill the tree private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - // Select one with the most leechers. This will level-wise fill the tree - var maxLeechers = -1 var selectedSource: SourceInfo = null listOfSources.foreach { source => - if (source != skipSourceInfo && - source.currentLeechers < Broadcast.MaxDegree && + if ((source.hostAddress != skipSourceInfo.hostAddress || + source.listenPort != skipSourceInfo.listenPort) && + source.currentLeechers < MultiTracker.MaxDegree && source.currentLeechers > maxLeechers) { selectedSource = source maxLeechers = source.currentLeechers @@ -561,7 +469,6 @@ extends Broadcast[T] with Logging with Serializable { // Update leecher count selectedSource.currentLeechers += 1 - return selectedSource } } @@ -569,35 +476,33 @@ extends Broadcast[T] with Logging with Serializable { class ServeMultipleRequests extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) + + var threadPool = Utils.newDaemonCachedThreadPool() + + override def run() { + var serverSocket = new ServerSocket(0) listenPort = serverSocket.getLocalPort + logInfo("ServeMultipleRequests started with " + serverSocket) - listenPortLock.synchronized { - listenPortLock.notifyAll() - } + listenPortLock.synchronized { listenPortLock.notifyAll() } try { while (!stopBroadcast) { var clientSocket: Socket = null try { - serverSocket.setSoTimeout(Broadcast.ServerSocketTimeout) + serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) clientSocket = serverSocket.accept } catch { - case e: Exception => { - logInfo("ServeMultipleRequests Timeout.") - } + case e: Exception => { } } + if (clientSocket != null) { - logInfo("Serve: Accepted new client connection: " + clientSocket) + logDebug("Serve: Accepted new client connection: " + clientSocket) try { threadPool.execute(new ServeSingleRequest(clientSocket)) } catch { - // In failure, close() socket here; else, the thread will close() it + // In failure, close socket here; else, the thread will close it case ioe: IOException => clientSocket.close() } } @@ -608,7 +513,6 @@ extends Broadcast[T] with Logging with Serializable { serverSocket.close() } } - // Shutdown the thread pool threadPool.shutdown() } @@ -631,19 +535,14 @@ extends Broadcast[T] with Logging with Serializable { sendFrom = rangeToSend._1 sendUntil = rangeToSend._2 - if (sendFrom == SourceInfo.StopBroadcast && - sendUntil == SourceInfo.StopBroadcast) { + // If not a valid range, stop broadcast + if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { stopBroadcast = true } else { - // Carry on sendObject } } catch { - // If something went wrong, e.g., the worker at the other end died etc. - // then close() everything up - case e: Exception => { - logInfo("ServeSingleRequest had a " + e) - } + case e: Exception => logError("ServeSingleRequest had a " + e) } finally { logInfo("ServeSingleRequest is closing streams and sockets") ois.close() @@ -655,152 +554,32 @@ extends Broadcast[T] with Logging with Serializable { private def sendObject() { // Wait till receiving the SourceInfo from Master while (totalBlocks == -1) { - totalBlocksLock.synchronized { - totalBlocksLock.wait() - } + totalBlocksLock.synchronized { totalBlocksLock.wait() } } for (i <- sendFrom until sendUntil) { while (i == hasBlocks) { - hasBlocksLock.synchronized { - hasBlocksLock.wait() - } + hasBlocksLock.synchronized { hasBlocksLock.wait() } } try { oos.writeObject(arrayOfBlocks(i)) oos.flush() } catch { - case e: Exception => { - logInfo("sendObject had a " + e) - } + case e: Exception => logError("sendObject had a " + e) } - logInfo("Sent block: " + i + " to " + clientSocket) + logDebug("Sent block: " + i + " to " + clientSocket) } } } } } -class TreeBroadcastFactory +private[spark] class TreeBroadcastFactory extends BroadcastFactory { - def initialize(isMaster: Boolean) = TreeBroadcast.initialize(isMaster) - def newBroadcast[T](value_ : T, isLocal: Boolean) = - new TreeBroadcast[T](value_, isLocal) -} + def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) } -private object TreeBroadcast -extends Logging { - val values = SparkEnv.get.cache.newKeySpace() + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TreeBroadcast[T](value_, isLocal, id) - var valueToGuidePortMap = Map[UUID, Int]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var isMaster_ = false - - private var trackMV: TrackMultipleValues = null - - private var MaxDegree_ : Int = 2 - - def initialize(isMaster__ : Boolean) { - synchronized { - if (!initialized) { - isMaster_ = isMaster__ - - if (isMaster) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - // TODO: Logging the following line makes the Spark framework ID not - // getting logged, cause it calls logInfo before log4j is initialized - logInfo("TrackMultipleValues started...") - } - - // Initialize DfsBroadcast to be used for broadcast variable persistence - DfsBroadcast.initialize - - initialized = true - } - } - } - - def isMaster = isMaster_ - - def registerValue(uuid: UUID, guidePort: Int) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap += (uuid -> guidePort) - logInfo("New value registered with the Tracker " + valueToGuidePortMap) - } - } - - def unregisterValue(uuid: UUID) { - valueToGuidePortMap.synchronized { - valueToGuidePortMap(uuid) = SourceInfo.TxOverGoToHDFS - logInfo("Value unregistered from the Tracker " + valueToGuidePortMap) - } - } - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(Broadcast.MasterTrackerPort) - logInfo("TrackMultipleValues" + serverSocket) - - try { - while (true) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(Broadcast.TrackerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - logInfo("TrackMultipleValues Timeout. Stopping listening...") - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - try { - val uuid = ois.readObject.asInstanceOf[UUID] - var guidePort = - if (valueToGuidePortMap.contains(uuid)) { - valueToGuidePortMap(uuid) - } else SourceInfo.TxNotStartedRetry - logInfo("TrackMultipleValues: Got new request: " + clientSocket + " for " + uuid + " : " + guidePort) - oos.writeObject(guidePort) - } catch { - case e: Exception => { - logInfo("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close() socket here; else, client thread will close() - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - - // Shutdown the thread pool - threadPool.shutdown() - } - } + def stop() { MultiTracker.stop() } } diff --git a/core/src/main/scala/spark/deploy/Command.scala b/core/src/main/scala/spark/deploy/Command.scala index 344888919a..577101e3c3 100644 --- a/core/src/main/scala/spark/deploy/Command.scala +++ b/core/src/main/scala/spark/deploy/Command.scala @@ -2,7 +2,7 @@ package spark.deploy import scala.collection.Map -case class Command( +private[spark] case class Command( mainClass: String, arguments: Seq[String], environment: Map[String, String]) { diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 141bbe4d57..d2b63d6e0d 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -7,13 +7,15 @@ import scala.collection.immutable.List import scala.collection.mutable.HashMap -sealed trait DeployMessage extends Serializable +private[spark] sealed trait DeployMessage extends Serializable // Worker to Master +private[spark] case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int) extends DeployMessage +private[spark] case class ExecutorStateChanged( jobId: String, execId: Int, @@ -23,11 +25,11 @@ case class ExecutorStateChanged( // Master to Worker -case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage -case class RegisterWorkerFailed(message: String) extends DeployMessage -case class KillExecutor(jobId: String, execId: Int) extends DeployMessage +private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage +private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage +private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage -case class LaunchExecutor( +private[spark] case class LaunchExecutor( jobId: String, execId: Int, jobDesc: JobDescription, @@ -38,33 +40,42 @@ case class LaunchExecutor( // Client to Master -case class RegisterJob(jobDescription: JobDescription) extends DeployMessage +private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage // Master to Client +private[spark] case class RegisteredJob(jobId: String) extends DeployMessage + +private[spark] case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) + +private[spark] case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String]) + +private[spark] case class JobKilled(message: String) // Internal message in Client -case object StopClient +private[spark] case object StopClient // MasterWebUI To Master -case object RequestMasterState +private[spark] case object RequestMasterState // Master to MasterWebUI +private[spark] case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo], completedJobs: List[JobInfo]) // WorkerWebUI to Worker -case object RequestWorkerState +private[spark] case object RequestWorkerState // Worker to WorkerWebUI +private[spark] case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
\ No newline at end of file diff --git a/core/src/main/scala/spark/deploy/ExecutorState.scala b/core/src/main/scala/spark/deploy/ExecutorState.scala index d6ff1c54ca..5dc0c54552 100644 --- a/core/src/main/scala/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/spark/deploy/ExecutorState.scala @@ -1,6 +1,6 @@ package spark.deploy -object ExecutorState +private[spark] object ExecutorState extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala index 8ae77b1038..20879c5f11 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/JobDescription.scala @@ -1,6 +1,6 @@ package spark.deploy -class JobDescription( +private[spark] class JobDescription( val name: String, val cores: Int, val memoryPerSlave: Int, diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala new file mode 100644 index 0000000000..8b2a71add5 --- /dev/null +++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala @@ -0,0 +1,58 @@ +package spark.deploy + +import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} + +import spark.deploy.worker.Worker +import spark.deploy.master.Master +import spark.util.AkkaUtils +import spark.{Logging, Utils} + +import scala.collection.mutable.ArrayBuffer + +private[spark] +class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging { + + val localIpAddress = Utils.localIpAddress + + var masterActor : ActorRef = _ + var masterActorSystem : ActorSystem = _ + var masterPort : Int = _ + var masterUrl : String = _ + + val slaveActorSystems = ArrayBuffer[ActorSystem]() + val slaveActors = ArrayBuffer[ActorRef]() + + def start() : String = { + logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.") + + /* Start the Master */ + val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0) + masterActorSystem = actorSystem + masterUrl = "spark://" + localIpAddress + ":" + masterPort + val actor = masterActorSystem.actorOf( + Props(new Master(localIpAddress, masterPort, 0)), name = "Master") + masterActor = actor + + /* Start the Slaves */ + for (slaveNum <- 1 to numSlaves) { + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0) + slaveActorSystems += actorSystem + val actor = actorSystem.actorOf( + Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)), + name = "Worker") + slaveActors += actor + } + + return masterUrl + } + + def stop() { + logInfo("Shutting down local Spark cluster.") + // Stop the slaves before the master so they don't get upset that it disconnected + slaveActorSystems.foreach(_.shutdown()) + slaveActorSystems.foreach(_.awaitTermination()) + masterActorSystem.shutdown() + masterActorSystem.awaitTermination() + } +} diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index c7fa8a3874..e51b0c5c15 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -4,6 +4,7 @@ import spark.deploy._ import akka.actor._ import akka.pattern.ask import akka.util.duration._ +import akka.pattern.AskTimeoutException import spark.{SparkException, Logging} import akka.remote.RemoteClientLifeCycleEvent import akka.remote.RemoteClientShutdown @@ -16,7 +17,7 @@ import akka.dispatch.Await * The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description, * and a listener for job events, and calls back the listener when various events occur. */ -class Client( +private[spark] class Client( actorSystem: ActorSystem, masterUrl: String, jobDescription: JobDescription, @@ -42,7 +43,6 @@ class Client( val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) try { master = context.actorFor(akkaUrl) - //master ! RegisterWorker(ip, port, cores, memory) master ! RegisterJob(jobDescription) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing @@ -101,9 +101,13 @@ class Client( def stop() { if (actor != null) { - val timeout = 1.seconds - val future = actor.ask(StopClient)(timeout) - Await.result(future, timeout) + try { + val timeout = 1.seconds + val future = actor.ask(StopClient)(timeout) + Await.result(future, timeout) + } catch { + case e: AskTimeoutException => // Ignore it, maybe master went away + } actor = null } } diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index 7d23baff32..a8fa982085 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -7,7 +7,7 @@ package spark.deploy.client * * Users of this API should *not* block inside the callback methods. */ -trait ClientListener { +private[spark] trait ClientListener { def connected(jobId: String): Unit def disconnected(): Unit diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index df9a36c7fe..bf0e7428ba 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -4,7 +4,7 @@ import spark.util.AkkaUtils import spark.{Logging, Utils} import spark.deploy.{Command, JobDescription} -object TestClient { +private[spark] object TestClient { class TestListener extends ClientListener with Logging { def connected(id: String) { diff --git a/core/src/main/scala/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/spark/deploy/client/TestExecutor.scala index 2e40e10d18..0e46db2272 100644 --- a/core/src/main/scala/spark/deploy/client/TestExecutor.scala +++ b/core/src/main/scala/spark/deploy/client/TestExecutor.scala @@ -1,6 +1,6 @@ package spark.deploy.client -object TestExecutor { +private[spark] object TestExecutor { def main(args: Array[String]) { println("Hello world!") while (true) { diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala index 335e00958c..1db2c32633 100644 --- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala @@ -2,7 +2,7 @@ package spark.deploy.master import spark.deploy.ExecutorState -class ExecutorInfo( +private[spark] class ExecutorInfo( val id: Int, val job: JobInfo, val worker: WorkerInfo, diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/JobInfo.scala index 31d48b82b9..8795c09cc1 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/JobInfo.scala @@ -5,6 +5,7 @@ import java.util.Date import akka.actor.ActorRef import scala.collection.mutable +private[spark] class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) { var state = JobState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] @@ -31,4 +32,13 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va } def coresLeft: Int = desc.cores - coresGranted + + private var _retryCount = 0 + + def retryCount = _retryCount + + def incrementRetryCount = { + _retryCount += 1 + _retryCount + } } diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala index 50b0c6f95b..2b70cf0191 100644 --- a/core/src/main/scala/spark/deploy/master/JobState.scala +++ b/core/src/main/scala/spark/deploy/master/JobState.scala @@ -1,7 +1,9 @@ package spark.deploy.master -object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { +private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { type JobState = Value val WAITING, RUNNING, FINISHED, FAILED = Value + + val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index c98dddea7b..6010f7cff2 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -1,21 +1,20 @@ package spark.deploy.master -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} - import akka.actor._ -import spark.{Logging, Utils} -import spark.util.AkkaUtils +import akka.actor.Terminated +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} + import java.text.SimpleDateFormat import java.util.Date -import akka.remote.RemoteClientLifeCycleEvent + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + import spark.deploy._ -import akka.remote.RemoteClientShutdown -import akka.remote.RemoteClientDisconnected -import spark.deploy.RegisterWorker -import spark.deploy.RegisterWorkerFailed -import akka.actor.Terminated +import spark.{Logging, SparkException, Utils} +import spark.util.AkkaUtils -class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { + +private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs var nextJobNumber = 0 @@ -81,12 +80,22 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { exec.state = state exec.job.actor ! ExecutorUpdated(execId, state, message) if (ExecutorState.isFinished(state)) { + val jobInfo = idToJob(jobId) // Remove this executor from the worker and job logInfo("Removing executor " + exec.fullId + " because it is " + state) - idToJob(jobId).removeExecutor(exec) + jobInfo.removeExecutor(exec) exec.worker.removeExecutor(exec) - // TODO: the worker would probably want to restart the executor a few times - schedule() + + // Only retry certain number of times so we don't go into an infinite loop. + if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) { + schedule() + } else { + val e = new SparkException("Job %s wth ID %s failed %d times.".format( + jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) + logError(e.getMessage, e) + throw e + //System.exit(1) + } } } case None => @@ -112,7 +121,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { addressToWorker.get(address).foreach(removeWorker) addressToJob.get(address).foreach(removeJob) } - + case RequestMasterState => { sender ! MasterState(ip + ":" + port, workers.toList, jobs.toList, completedJobs.toList) } @@ -203,7 +212,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { } } -object Master { +private[spark] object Master { def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index 0f7a92bdd0..1b1c3dd0ad 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -6,7 +6,7 @@ import spark.Utils /** * Command-line parser for the master. */ -class MasterArguments(args: Array[String]) { +private[spark] class MasterArguments(args: Array[String]) { var ip = Utils.localIpAddress() var port = 7077 var webUiPort = 8080 @@ -51,7 +51,7 @@ class MasterArguments(args: Array[String]) { */ def printUsageAndExit(exitCode: Int) { System.err.println( - "Usage: spark-master [options]\n" + + "Usage: Master [options]\n" + "\n" + "Options:\n" + " -i IP, --ip IP IP address or DNS name to listen on\n" + diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index f03c0a0229..700a41c770 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -10,6 +10,7 @@ import cc.spray.directives._ import cc.spray.typeconversion.TwirlSupport._ import spark.deploy._ +private[spark] class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/master/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" @@ -22,7 +23,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct completeWith { val future = master ? RequestMasterState future.map { - masterState => masterui.html.index.render(masterState.asInstanceOf[MasterState]) + masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState]) } } } ~ @@ -36,7 +37,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct // A bit ugly an inefficient, but we won't have a number of jobs // so large that it will make a significant difference. (masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match { - case Some(job) => masterui.html.job_details.render(job) + case Some(job) => spark.deploy.master.html.job_details.render(job) case _ => null } } diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 59474a0945..16b3f9b653 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -3,7 +3,7 @@ package spark.deploy.master import akka.actor.ActorRef import scala.collection.mutable -class WorkerInfo( +private[spark] class WorkerInfo( val id: String, val host: String, val port: Int, diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 3e24380810..07ae7bca78 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -13,7 +13,7 @@ import spark.deploy.ExecutorStateChanged /** * Manages the execution of one executor process. */ -class ExecutorRunner( +private[spark] class ExecutorRunner( val jobId: String, val execId: Int, val jobDesc: JobDescription, @@ -29,12 +29,25 @@ class ExecutorRunner( val fullId = jobId + "/" + execId var workerThread: Thread = null var process: Process = null + var shutdownHook: Thread = null def start() { workerThread = new Thread("ExecutorRunner for " + fullId) { override def run() { fetchAndRunExecutor() } } workerThread.start() + + // Shutdown hook that kills actors on shutdown. + shutdownHook = new Thread() { + override def run() { + if (process != null) { + logInfo("Shutdown hook killing child process.") + process.destroy() + process.waitFor() + } + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) } /** Stop this executor runner, including killing the process it launched */ @@ -45,40 +58,10 @@ class ExecutorRunner( if (process != null) { logInfo("Killing process!") process.destroy() + process.waitFor() } worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None) - } - } - - /** - * Download a file requested by the executor. Supports fetching the file in a variety of ways, - * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. - */ - def fetchFile(url: String, targetDir: File) { - val filename = url.split("/").last - val targetFile = new File(targetDir, filename) - if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) { - // Use the java.net library to fetch it - logInfo("Fetching " + url + " to " + targetFile) - val in = new URL(url).openStream() - val out = new FileOutputStream(targetFile) - Utils.copyStream(in, out, true) - } else { - // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val uri = new URI(url) - val conf = new Configuration() - val fs = FileSystem.get(uri, conf) - val in = fs.open(new Path(uri)) - val out = new FileOutputStream(targetFile) - Utils.copyStream(in, out, true) - } - // Decompress the file if it's a .tar or .tar.gz - if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xzf", filename), targetDir) - } else if (filename.endsWith(".tar")) { - logInfo("Untarring " + filename) - Utils.execute(Seq("tar", "-xf", filename), targetDir) + Runtime.getRuntime.removeShutdownHook(shutdownHook) } } @@ -92,7 +75,8 @@ class ExecutorRunner( def buildCommandSeq(): Seq[String] = { val command = jobDesc.command - val runScript = new File(sparkHome, "run").getCanonicalPath + val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"; + val runScript = new File(sparkHome, script).getCanonicalPath Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables) } @@ -101,7 +85,12 @@ class ExecutorRunner( val out = new FileOutputStream(file) new Thread("redirect output to " + file) { override def run() { - Utils.copyStream(in, out, true) + try { + Utils.copyStream(in, out, true) + } catch { + case e: IOException => + logInfo("Redirection to " + file + " closed: " + e.getMessage) + } } }.start() } @@ -131,6 +120,9 @@ class ExecutorRunner( } env.put("SPARK_CORES", cores.toString) env.put("SPARK_MEMORY", memory.toString) + // In case we are running this from within the Spark Shell, avoid creating a "scala" + // parent process for the executor command + env.put("SPARK_LAUNCH_WITH_SCALA", "0") process = builder.start() // Redirect its stdout and stderr to files diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 0a80463c0b..474c9364fd 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -16,7 +16,14 @@ import spark.deploy.RegisterWorkerFailed import akka.actor.Terminated import java.io.File -class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String) +private[spark] class Worker( + ip: String, + port: Int, + webUiPort: Int, + cores: Int, + memory: Int, + masterUrl: String, + workDirPath: String = null) extends Actor with Logging { val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs @@ -37,7 +44,11 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas def memoryFree: Int = memory - memoryUsed def createWorkDir() { - workDir = new File(sparkHome, "work") + workDir = if (workDirPath != null) { + new File(workDirPath) + } else { + new File(sparkHome, "work") + } try { if (!workDir.exists() && !workDir.mkdirs()) { logError("Failed to create work directory " + workDir) @@ -153,14 +164,19 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas def generateWorkerId(): String = { "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port) } + + override def postStop() { + executors.values.foreach(_.kill()) + } } -object Worker { +private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val actor = actorSystem.actorOf( - Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)), + Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, + args.master, args.workDir)), name = "Worker") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 1efe8304ea..60dc107a4c 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -8,13 +8,14 @@ import java.lang.management.ManagementFactory /** * Command-line parser for the master. */ -class WorkerArguments(args: Array[String]) { +private[spark] class WorkerArguments(args: Array[String]) { var ip = Utils.localIpAddress() var port = 0 var webUiPort = 8081 var cores = inferDefaultCores() var memory = inferDefaultMemory() var master: String = null + var workDir: String = null // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { @@ -29,6 +30,9 @@ class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } + if (System.getenv("SPARK_WORKER_DIR") != null) { + workDir = System.getenv("SPARK_WORKER_DIR") + } parse(args.toList) @@ -49,6 +53,10 @@ class WorkerArguments(args: Array[String]) { memory = value parse(tail) + case ("--work-dir" | "-d") :: value :: tail => + workDir = value + parse(tail) + case "--webui-port" :: IntParam(value) :: tail => webUiPort = value parse(tail) @@ -77,13 +85,14 @@ class WorkerArguments(args: Array[String]) { */ def printUsageAndExit(exitCode: Int) { System.err.println( - "Usage: spark-worker [options] <master>\n" + + "Usage: Worker [options] <master>\n" + "\n" + "Master must be a URL of the form spark://hostname:port\n" + "\n" + "Options:\n" + " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + + " -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" + " -i IP, --ip IP IP address or DNS name to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index 58a05e1a38..d06f4884ee 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -9,6 +9,7 @@ import cc.spray.Directives import cc.spray.typeconversion.TwirlSupport._ import spark.deploy.{WorkerState, RequestWorkerState} +private[spark] class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives { val RESOURCE_DIR = "spark/deploy/worker/webui" val STATIC_RESOURCE_DIR = "spark/deploy/static" @@ -21,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct completeWith{ val future = worker ? RequestWorkerState future.map { workerState => - workerui.html.index(workerState.asInstanceOf[WorkerState]) + spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState]) } } } ~ diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index e3958cec51..dfdb22024e 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -1,10 +1,12 @@ package spark.executor import java.io.{File, FileOutputStream} -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} import java.util.concurrent._ -import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.FileUtil + +import scala.collection.mutable.{ArrayBuffer, Map, HashMap} import spark.broadcast._ import spark.scheduler._ @@ -14,11 +16,16 @@ import java.nio.ByteBuffer /** * The Mesos executor for Spark. */ -class Executor extends Logging { - var classLoader: ClassLoader = null +private[spark] class Executor extends Logging { + var urlClassLoader : ExecutorURLClassLoader = null var threadPool: ExecutorService = null var env: SparkEnv = null + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0)) initLogging() @@ -32,15 +39,13 @@ class Executor extends Logging { System.setProperty(key, value) } + // Create our ClassLoader and set it on this thread + urlClassLoader = createClassLoader() + Thread.currentThread.setContextClassLoader(urlClassLoader) + // Initialize Spark environment (using system properties read above) env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false) SparkEnv.set(env) - // Old stuff that isn't yet using env - Broadcast.initialize(false) - - // Create our ClassLoader (using spark properties) and set it on this thread - classLoader = createClassLoader() - Thread.currentThread.setContextClassLoader(classLoader) // Start worker thread pool threadPool = new ThreadPoolExecutor( @@ -56,15 +61,17 @@ class Executor extends Logging { override def run() { SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) + Thread.currentThread.setContextClassLoader(urlClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) try { SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear() - val task = ser.deserialize[Task[Any]](serializedTask, classLoader) + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + updateDependencies(taskFiles, taskJars) + val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + logInfo("Its generation is " + task.generation) env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(taskId.toInt) val accumUpdates = Accumulators.values @@ -97,25 +104,15 @@ class Executor extends Logging { * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(): ClassLoader = { + private def createClassLoader(): ExecutorURLClassLoader = { var loader = this.getClass.getClassLoader - // If any JAR URIs are given through spark.jar.uris, fetch them to the - // current directory and put them all on the classpath. We assume that - // each URL has a unique file name so that no local filenames will clash - // in this process. This is guaranteed by ClusterScheduler. - val uris = System.getProperty("spark.jar.uris", "") - val localFiles = ArrayBuffer[String]() - for (uri <- uris.split(",").filter(_.size > 0)) { - val url = new URL(uri) - val filename = url.getPath.split("/").last - downloadFile(url, filename) - localFiles += filename - } - if (localFiles.size > 0) { - val urls = localFiles.map(f => new File(f).toURI.toURL).toArray - loader = new URLClassLoader(urls, loader) - } + // For each of the jars in the jarSet, add them to the class loader. + // We assume each of the files has already been fetched. + val urls = currentJars.keySet.map { uri => + new File(uri.split("/").last).toURI.toURL + }.toArray + loader = new URLClassLoader(urls, loader) // If the REPL is in use, add another ClassLoader that will read // new classes defined by the REPL as the user types code @@ -134,13 +131,31 @@ class Executor extends Logging { } } - return loader + return new ExecutorURLClassLoader(Array(), loader) } - // Download a file from a given URL to the local filesystem - private def downloadFile(url: URL, localPath: String) { - val in = url.openStream() - val out = new FileOutputStream(localPath) - Utils.copyStream(in, out, true) + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. + */ + private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(".", localName).toURI.toURL + if (!urlClassLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + urlClassLoader.addURL(url) + } + } } } diff --git a/core/src/main/scala/spark/executor/ExecutorBackend.scala b/core/src/main/scala/spark/executor/ExecutorBackend.scala index 24c8776f31..e97e509700 100644 --- a/core/src/main/scala/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/ExecutorBackend.scala @@ -6,6 +6,6 @@ import spark.TaskState.TaskState /** * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ -trait ExecutorBackend { +private[spark] trait ExecutorBackend { def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) } diff --git a/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala new file mode 100644 index 0000000000..5beb4d049e --- /dev/null +++ b/core/src/main/scala/spark/executor/ExecutorURLClassLoader.scala @@ -0,0 +1,14 @@ +package spark.executor + +import java.net.{URLClassLoader, URL} + +/** + * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. + */ +private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends URLClassLoader(urls, parent) { + + override def addURL(url: URL) { + super.addURL(url) + } +} diff --git a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala index 50f4e41ede..eeab3959c6 100644 --- a/core/src/main/scala/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/MesosExecutorBackend.scala @@ -8,7 +8,7 @@ import com.google.protobuf.ByteString import spark.{Utils, Logging} import spark.TaskState -class MesosExecutorBackend(executor: Executor) +private[spark] class MesosExecutorBackend(executor: Executor) extends MesosExecutor with ExecutorBackend with Logging { @@ -59,7 +59,7 @@ class MesosExecutorBackend(executor: Executor) /** * Entry point for Mesos executor. */ -object MesosExecutorBackend { +private[spark] object MesosExecutorBackend { def main(args: Array[String]) { MesosNativeLibrary.load() // Create a new Executor and start it running diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 26b163de0a..915f71ba9f 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -14,7 +14,7 @@ import spark.scheduler.cluster.RegisterSlaveFailed import spark.scheduler.cluster.RegisterSlave -class StandaloneExecutorBackend( +private[spark] class StandaloneExecutorBackend( executor: Executor, masterUrl: String, slaveId: String, @@ -62,7 +62,7 @@ class StandaloneExecutorBackend( } } -object StandaloneExecutorBackend { +private[spark] object StandaloneExecutorBackend { def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index 451faee66e..80262ab7b4 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -11,6 +11,7 @@ import java.nio.channels.spi._ import java.net._ +private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { channel.configureBlocking(false) @@ -23,8 +24,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex var onExceptionCallback: (Connection, Exception) => Unit = null var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - lazy val remoteAddress = getRemoteAddress() - lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) + val remoteAddress = getRemoteAddress() + val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) def key() = channel.keyFor(selector) @@ -39,7 +40,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex } def close() { - key.cancel() + val k = key() + if (k != null) { + k.cancel() + } channel.close() callOnCloseCallback() } @@ -99,7 +103,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex } -class SendingConnection(val address: InetSocketAddress, selector_ : Selector) +private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) extends Connection(SocketChannel.open, selector_) { class Outbox(fair: Int = 0) { @@ -111,7 +115,7 @@ extends Connection(SocketChannel.open, selector_) { messages.synchronized{ /*messages += message*/ messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") } } @@ -134,9 +138,12 @@ extends Connection(SocketChannel.open, selector_) { if (!message.started) logDebug("Starting to send [" + message + "]") message.started = true return chunk + } else { + /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + "] in " + message.timeTaken ) } - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) } } None @@ -159,10 +166,11 @@ extends Connection(SocketChannel.open, selector_) { } logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") return chunk - } - /*messages -= message*/ - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + } else { + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + + "] in " + message.timeTaken ) + } } } None @@ -216,7 +224,7 @@ extends Connection(SocketChannel.open, selector_) { while(true) { if (currentBuffers.size == 0) { outbox.synchronized { - outbox.getChunk match { + outbox.getChunk() match { case Some(chunk) => { currentBuffers ++= chunk.buffers } @@ -252,7 +260,7 @@ extends Connection(SocketChannel.open, selector_) { } -class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) extends Connection(channel_, selector_) { class Inbox() { diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index 1a22d06cc8..da39108164 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -14,19 +14,21 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Promise, ExecutionContext, Future} +import akka.dispatch.{Await, Promise, ExecutionContext, Future} +import akka.util.Duration +import akka.util.duration._ -case class ConnectionManagerId(host: String, port: Int) { +private[spark] case class ConnectionManagerId(host: String, port: Int) { def toSocketAddress() = new InetSocketAddress(host, port) } -object ConnectionManagerId { +private[spark] object ConnectionManagerId { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) } } -class ConnectionManager(port: Int) extends Logging { +private[spark] class ConnectionManager(port: Int) extends Logging { class MessageStatus( val message: Message, @@ -111,7 +113,7 @@ class ConnectionManager(port: Int) extends Logging { val selectedKeysCount = selector.select() if (selectedKeysCount == 0) { - logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") + logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") } if (selectorThread.isInterrupted) { logInfo("Selector thread was interrupted!") @@ -165,7 +167,6 @@ class ConnectionManager(port: Int) extends Logging { } def removeConnection(connection: Connection) { - /*logInfo("Removing connection")*/ connectionsByKey -= connection.key if (connection.isInstanceOf[SendingConnection]) { val sendingConnection = connection.asInstanceOf[SendingConnection] @@ -233,7 +234,7 @@ class ConnectionManager(port: Int) extends Logging { def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logInfo("Received [" + message + "] from [" + connectionManagerId + "]") + logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { @@ -247,7 +248,7 @@ class ConnectionManager(port: Int) extends Logging { } private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { if (bufferMessage.hasAckId) { @@ -274,15 +275,15 @@ class ConnectionManager(port: Int) extends Logging { logDebug("Calling back") onReceiveCallback(bufferMessage, connectionManagerId) } else { - logWarning("Not calling back as callback is null") + logDebug("Not calling back as callback is null") None } if (ackMessage.isDefined) { if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) + logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logWarning("Response to " + bufferMessage + " does not have ack id set") + logDebug("Response to " + bufferMessage + " does not have ack id set") ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id } } @@ -305,7 +306,7 @@ class ConnectionManager(port: Int) extends Logging { } val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection()) message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") + logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") /*connection.send(message)*/ sendMessageRequests.synchronized { sendMessageRequests += ((message, connection)) @@ -325,7 +326,7 @@ class ConnectionManager(port: Int) extends Logging { } def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() + Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf) } def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { @@ -347,7 +348,7 @@ class ConnectionManager(port: Int) extends Logging { } -object ConnectionManager { +private[spark] object ConnectionManager { def main(args: Array[String]) { @@ -402,7 +403,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val mb = size * count / 1024.0 / 1024.0 @@ -429,7 +433,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis val ms = finishTime - startTime @@ -456,7 +463,10 @@ object ConnectionManager { (0 until count).map(i => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) + }).foreach(f => { + val g = Await.result(f, 1 second) + if (!g.isDefined) println("Failed") + }) val finishTime = System.currentTimeMillis Thread.sleep(1000) val mb = size * count / 1024.0 / 1024.0 diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala index 5d21bb793f..47ceaf3c07 100644 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -8,7 +8,10 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress -object ConnectionManagerTest extends Logging{ +import akka.dispatch.Await +import akka.util.duration._ + +private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { if (args.length < 2) { println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") @@ -53,7 +56,7 @@ object ConnectionManagerTest extends Logging{ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) }) - val results = futures.map(f => f()) + val results = futures.map(f => Await.result(f, 1.second)) val finishTime = System.currentTimeMillis Thread.sleep(5000) diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala index 2e85803679..525751b5bf 100644 --- a/core/src/main/scala/spark/network/Message.scala +++ b/core/src/main/scala/spark/network/Message.scala @@ -7,8 +7,9 @@ import scala.collection.mutable.ArrayBuffer import java.nio.ByteBuffer import java.net.InetAddress import java.net.InetSocketAddress +import storage.BlockManager -class MessageChunkHeader( +private[spark] class MessageChunkHeader( val typ: Long, val id: Int, val totalSize: Int, @@ -36,7 +37,7 @@ class MessageChunkHeader( " and sizes " + totalSize + " / " + chunkSize + " bytes" } -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { +private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { val size = if (buffer == null) 0 else buffer.remaining lazy val buffers = { val ab = new ArrayBuffer[ByteBuffer]() @@ -50,7 +51,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" } -abstract class Message(val typ: Long, val id: Int) { +private[spark] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null var started = false var startTime = -1L @@ -64,10 +65,10 @@ abstract class Message(val typ: Long, val id: Int) { def timeTaken(): String = (finishTime - startTime).toString + " ms" - override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" + override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" } -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) +private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) extends Message(Message.BUFFER_MESSAGE, id_) { val initialSize = currentSize() @@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) { while(!buffers.isEmpty) { val buffer = buffers(0) if (buffer.remaining == 0) { + BlockManager.dispose(buffer) buffers -= buffer } else { val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate + buffer.duplicate() } else { buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] } @@ -147,11 +149,10 @@ extends Message(Message.BUFFER_MESSAGE, id_) { } else { "BufferMessage(id = " + id + ", size = " + size + ")" } - } } -object MessageChunkHeader { +private[spark] object MessageChunkHeader { val HEADER_SIZE = 40 def create(buffer: ByteBuffer): MessageChunkHeader = { @@ -172,7 +173,7 @@ object MessageChunkHeader { } } -object Message { +private[spark] object Message { val BUFFER_MESSAGE = 1111111111L var lastId = 1 diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala index e1ba7c06c0..a174d5f403 100644 --- a/core/src/main/scala/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/spark/network/ReceiverTest.scala @@ -3,7 +3,7 @@ package spark.network import java.nio.ByteBuffer import java.net.InetAddress -object ReceiverTest { +private[spark] object ReceiverTest { def main(args: Array[String]) { val manager = new ConnectionManager(9999) diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala index 4ab6dd3414..a4ff69e4d2 100644 --- a/core/src/main/scala/spark/network/SenderTest.scala +++ b/core/src/main/scala/spark/network/SenderTest.scala @@ -3,7 +3,7 @@ package spark.network import java.nio.ByteBuffer import java.net.InetAddress -object SenderTest { +private[spark] object SenderTest { def main(args: Array[String]) { diff --git a/core/src/main/scala/spark/package.scala b/core/src/main/scala/spark/package.scala new file mode 100644 index 0000000000..389ec4da3e --- /dev/null +++ b/core/src/main/scala/spark/package.scala @@ -0,0 +1,15 @@ +/** + * Core Spark functionality. [[spark.SparkContext]] serves as the main entry point to Spark, while + * [[spark.RDD]] is the data type representing a distributed collection, and provides most + * parallel operations. + * + * In addition, [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value + * pairs, such as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations + * available only on RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations + * available on RDDs that can be saved as SequenceFiles. These operations are automatically + * available on any RDD of the right type (e.g. RDD[(Int, Int)] through implicit conversions when + * you `import spark.SparkContext._`. + */ +package object spark { + // For package docs only +} diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala index e6535836ab..42f46e06ed 100644 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -12,7 +12,7 @@ import spark.scheduler.JobListener * a result of type U for each partition, and that the action returns a partial or complete result * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt). */ -class ApproximateActionListener[T, U, R]( +private[spark] class ApproximateActionListener[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala index 4772e43ef0..75713b2eaa 100644 --- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala +++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala @@ -4,7 +4,7 @@ package spark.partial * An object that computes a function incrementally by merging in results of type U from multiple * tasks. Allows partial evaluation at any point by calling currentResult(). */ -trait ApproximateEvaluator[U, R] { +private[spark] trait ApproximateEvaluator[U, R] { def merge(outputId: Int, taskResult: U): Unit def currentResult(): R } diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala index 1bc90d6b39..daf2c5170c 100644 --- a/core/src/main/scala/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/spark/partial/CountEvaluator.scala @@ -8,7 +8,7 @@ import cern.jet.stat.Probability * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might * be best to make this a special case of GroupedCountEvaluator with one group. */ -class CountEvaluator(totalOutputs: Int, confidence: Double) +private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[Long, BoundedDouble] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala index 3e631c0efc..01fbb8a11b 100644 --- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala @@ -14,7 +14,7 @@ import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} /** * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ -class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) +private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala index 2a9ccba205..c622df5220 100644 --- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala +++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala @@ -12,7 +12,7 @@ import spark.util.StatCounter /** * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. */ -class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) +private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala index 6a2ec7a7bd..20fa55cff2 100644 --- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala +++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala @@ -12,7 +12,7 @@ import spark.util.StatCounter /** * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. */ -class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) +private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala index b8c7cb8863..762c85400d 100644 --- a/core/src/main/scala/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala @@ -7,7 +7,7 @@ import spark.util.StatCounter /** * An ApproximateEvaluator for means. */ -class MeanEvaluator(totalOutputs: Int, confidence: Double) +private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala index 6263ee3518..443abba5cd 100644 --- a/core/src/main/scala/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/spark/partial/StudentTCacher.scala @@ -7,7 +7,7 @@ import cern.jet.stat.Probability * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate * confidence intervals for many keys. */ -class StudentTCacher(confidence: Double) { +private[spark] class StudentTCacher(confidence: Double) { val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala index 0357a6bff8..58fb60f441 100644 --- a/core/src/main/scala/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/spark/partial/SumEvaluator.scala @@ -9,7 +9,7 @@ import spark.util.StatCounter * together, then uses the formula for the variance of two independent random variables to get * a variance for the result and compute a confidence interval. */ -class SumEvaluator(totalOutputs: Int, confidence: Double) +private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { var outputsMerged = 0 diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index ea009f0f4f..cb73976aed 100644 --- a/core/src/main/scala/spark/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -1,13 +1,20 @@ -package spark +package spark.rdd import scala.collection.mutable.HashMap -class BlockRDDSplit(val blockId: String, idx: Int) extends Split { +import spark.Dependency +import spark.RDD +import spark.SparkContext +import spark.SparkEnv +import spark.Split + +private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split { val index = idx } - -class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) { +private[spark] +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) + extends RDD[T](sc) { @transient val splits_ = (0 until blockIds.size).map(i => { diff --git a/core/src/main/scala/spark/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index e26041555a..7c354b6b2e 100644 --- a/core/src/main/scala/spark/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -1,9 +1,16 @@ -package spark +package spark.rdd +import spark.NarrowDependency +import spark.RDD +import spark.SparkContext +import spark.Split + +private[spark] class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable { override val index: Int = idx } +private[spark] class CartesianRDD[T: ClassManifest, U:ClassManifest]( sc: SparkContext, rdd1: RDD[T], @@ -44,4 +51,4 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) } ) -}
\ No newline at end of file +} diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 6959917d14..ace2500627 100644 --- a/core/src/main/scala/spark/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -1,21 +1,33 @@ -package spark +package spark.rdd import java.net.URL import java.io.EOFException import java.io.ObjectInputStream + import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -sealed trait CoGroupSplitDep extends Serializable -case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep -case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +import spark.Aggregator +import spark.Dependency +import spark.Logging +import spark.OneToOneDependency +import spark.Partitioner +import spark.RDD +import spark.ShuffleDependency +import spark.SparkEnv +import spark.Split + +private[spark] sealed trait CoGroupSplitDep extends Serializable +private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep +private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +private[spark] class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { override val index: Int = idx override def hashCode(): Int = idx } -class CoGroupAggregator +private[spark] class CoGroupAggregator extends Aggregator[Any, Any, ArrayBuffer[Any]]( { x => ArrayBuffer(x) }, { (b, x) => b += x }, @@ -36,8 +48,7 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) deps += new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) - deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]]( - context.newShuffleId, rdd, aggr, part) + deps += new ShuffleDependency[Any, Any, ArrayBuffer[Any]](rdd, Some(aggr), part) } } deps.toList diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala new file mode 100644 index 0000000000..0967f4f5df --- /dev/null +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -0,0 +1,47 @@ +package spark.rdd + +import spark.NarrowDependency +import spark.RDD +import spark.Split + +private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split + +/** + * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of + * this RDD computes one or more of the parent ones. Will produce exactly `maxPartitions` if the + * parent had more than this many partitions, or fewer if the parent had fewer. + * + * This transformation is useful when an RDD with many partitions gets filtered into a smaller one, + * or to avoid having a large number of small tasks when processing a directory with many files. + */ +class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int) + extends RDD[T](prev.context) { + + @transient val splits_ : Array[Split] = { + val prevSplits = prev.splits + if (prevSplits.length < maxPartitions) { + prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) } + } else { + (0 until maxPartitions).map { i => + val rangeStart = (i * prevSplits.length) / maxPartitions + val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions + new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd)) + }.toArray + } + } + + override def splits = splits_ + + override def compute(split: Split): Iterator[T] = { + split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { + parentSplit => prev.iterator(parentSplit) + } + } + + val dependencies = List( + new NarrowDependency(prev) { + def getParents(id: Int): Seq[Int] = + splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index) + } + ) +} diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala new file mode 100644 index 0000000000..dfe9dc73f3 --- /dev/null +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -0,0 +1,12 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +private[spark] +class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) { + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = prev.iterator(split).filter(f) +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala new file mode 100644 index 0000000000..3534dc8057 --- /dev/null +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -0,0 +1,16 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +private[spark] +class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: T => TraversableOnce[U]) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = prev.iterator(split).flatMap(f) +} diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala new file mode 100644 index 0000000000..e30564f2da --- /dev/null +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -0,0 +1,12 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +private[spark] +class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) { + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index f282a4023b..bf29a1f075 100644 --- a/core/src/main/scala/spark/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -1,4 +1,4 @@ -package spark +package spark.rdd import java.io.EOFException import java.util.NoSuchElementException @@ -15,10 +15,16 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils +import spark.Dependency +import spark.RDD +import spark.SerializableWritable +import spark.SparkContext +import spark.Split + /** * A Spark split class that wraps around a Hadoop InputSplit. */ -class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) +private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) extends Split with Serializable { @@ -42,7 +48,8 @@ class HadoopRDD[K, V]( minSplits: Int) extends RDD[(K, V)](sc) { - val serializableConf = new SerializableWritable(conf) + // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) @transient val splits_ : Array[Split] = { @@ -66,7 +73,7 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopSplit] var reader: RecordReader[K, V] = null - val conf = serializableConf.value + val conf = confBroadcast.value.value val fmt = createInputFormat(conf) reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala new file mode 100644 index 0000000000..b2c7a1cb9e --- /dev/null +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -0,0 +1,16 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +private[spark] +class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: Iterator[T] => Iterator[U]) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = f(prev.iterator(split)) +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala new file mode 100644 index 0000000000..adc541694e --- /dev/null +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -0,0 +1,21 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +/** + * A variant of the MapPartitionsRDD that passes the split index into the + * closure. This can be used to generate or collect partition specific + * information such as the number of tuples in a partition. + */ +private[spark] +class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: (Int, Iterator[T]) => Iterator[U]) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = f(split.index, prev.iterator(split)) +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala new file mode 100644 index 0000000000..59bedad8ef --- /dev/null +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -0,0 +1,16 @@ +package spark.rdd + +import spark.OneToOneDependency +import spark.RDD +import spark.Split + +private[spark] +class MappedRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: T => U) + extends RDD[U](prev.context) { + + override def splits = prev.splits + override val dependencies = List(new OneToOneDependency(prev)) + override def compute(split: Split) = prev.iterator(split).map(f) +}
\ No newline at end of file diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index d024d38aa9..dcbceab246 100644 --- a/core/src/main/scala/spark/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -1,4 +1,4 @@ -package spark +package spark.rdd import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.Writable @@ -13,6 +13,13 @@ import org.apache.hadoop.mapreduce.TaskAttemptID import java.util.Date import java.text.SimpleDateFormat +import spark.Dependency +import spark.RDD +import spark.SerializableWritable +import spark.SparkContext +import spark.Split + +private[spark] class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) extends Split { @@ -28,7 +35,9 @@ class NewHadoopRDD[K, V]( @transient conf: Configuration) extends RDD[(K, V)](sc) { - private val serializableConf = new SerializableWritable(conf) + // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it + val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + // private val serializableConf = new SerializableWritable(conf) private val jobtrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") @@ -41,7 +50,7 @@ class NewHadoopRDD[K, V]( @transient private val splits_ : Array[Split] = { val inputFormat = inputFormatClass.newInstance - val jobContext = new JobContext(serializableConf.value, jobId) + val jobContext = new JobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Split](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -54,9 +63,9 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] - val conf = serializableConf.value + val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) - val context = new TaskAttemptContext(serializableConf.value, attemptId) + val context = new TaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance val reader = format.createRecordReader(split.serializableHadoopSplit.value, context) reader.initialize(split.serializableHadoopSplit.value, context) diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 3103d7889b..98ea0c92d6 100644 --- a/core/src/main/scala/spark/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -1,4 +1,4 @@ -package spark +package spark.rdd import java.io.PrintWriter import java.util.StringTokenizer @@ -8,6 +8,12 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source +import spark.OneToOneDependency +import spark.RDD +import spark.SparkEnv +import spark.Split + + /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. diff --git a/core/src/main/scala/spark/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index 8ef40d8d9e..87a5268f27 100644 --- a/core/src/main/scala/spark/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -1,7 +1,14 @@ -package spark +package spark.rdd import java.util.Random +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand +import spark.RDD +import spark.OneToOneDependency +import spark.Split + +private[spark] class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { override val index: Int = prev.index } @@ -28,19 +35,21 @@ class SampledRDD[T: ClassManifest]( override def compute(splitIn: Split) = { val split = splitIn.asInstanceOf[SampledRDDSplit] - val rg = new Random(split.seed) - // Sampling with replacement (TODO: use reservoir sampling to make this more efficient?) if (withReplacement) { - val oldData = prev.iterator(split.prev).toArray - val sampleSize = (oldData.size * frac).ceil.toInt - val sampledData = { - // all of oldData's indices are candidates, even if sampleSize < oldData.size - for (i <- 1 to sampleSize) - yield oldData(rg.nextInt(oldData.size)) + // For large datasets, the expected number of occurrences of each element in a sample with + // replacement is Poisson(frac). We use that to get a count for each element. + val poisson = new Poisson(frac, new DRand(split.seed)) + prev.iterator(split.prev).flatMap { element => + val count = poisson.nextInt() + if (count == 0) { + Iterator.empty // Avoid object allocation when we return 0 items, which is quite often + } else { + Iterator.fill(count)(element) + } } - sampledData.iterator } else { // Sampling without replacement - prev.iterator(split.prev).filter(x => (rg.nextDouble <= frac)) + val rand = new Random(split.seed) + prev.iterator(split.prev).filter(x => (rand.nextDouble <= frac)) } } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala new file mode 100644 index 0000000000..be120acc71 --- /dev/null +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -0,0 +1,142 @@ +package spark.rdd + +import scala.collection.mutable.ArrayBuffer +import java.util.{HashMap => JHashMap} + +import spark.Aggregator +import spark.Partitioner +import spark.RangePartitioner +import spark.RDD +import spark.ShuffleDependency +import spark.SparkEnv +import spark.Split + +private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { + override val index = idx + override def hashCode(): Int = idx +} + + +/** + * The resulting RDD from a shuffle (e.g. repartitioning of data). + */ +abstract class ShuffledRDD[K, V, C]( + @transient parent: RDD[(K, V)], + aggregator: Option[Aggregator[K, V, C]], + part: Partitioner) + extends RDD[(K, C)](parent.context) { + + override val partitioner = Some(part) + + @transient + val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + + override def splits = splits_ + + override def preferredLocations(split: Split) = Nil + + val dep = new ShuffleDependency(parent, aggregator, part) + override val dependencies = List(dep) +} + + +/** + * Repartition a key-value pair RDD. + */ +class RepartitionShuffledRDD[K, V]( + @transient parent: RDD[(K, V)], + part: Partitioner) + extends ShuffledRDD[K, V, V]( + parent, + None, + part) { + + override def compute(split: Split): Iterator[(K, V)] = { + val buf = new ArrayBuffer[(K, V)] + val fetcher = SparkEnv.get.shuffleFetcher + def addTupleToBuffer(k: K, v: V) = { buf += Tuple(k, v) } + fetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer) + buf.iterator + } +} + + +/** + * A sort-based shuffle (that doesn't apply aggregation). It does so by first + * repartitioning the RDD by range, and then sort within each range. + */ +class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V]( + @transient parent: RDD[(K, V)], + ascending: Boolean, + numSplits: Int) + extends RepartitionShuffledRDD[K, V]( + parent, + new RangePartitioner(numSplits, parent, ascending)) { + + override def compute(split: Split): Iterator[(K, V)] = { + // By separating this from RepartitionShuffledRDD, we avoided a + // buf.iterator.toArray call, thus avoiding building up the buffer twice. + val buf = new ArrayBuffer[(K, V)] + def addTupleToBuffer(k: K, v: V) { buf += ((k, v)) } + SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer) + if (ascending) { + buf.sortWith((x, y) => x._1 < y._1).iterator + } else { + buf.sortWith((x, y) => x._1 > y._1).iterator + } + } +} + + +/** + * The resulting RDD from shuffle and running (hash-based) aggregation. + */ +class ShuffledAggregatedRDD[K, V, C]( + @transient parent: RDD[(K, V)], + aggregator: Aggregator[K, V, C], + part : Partitioner) + extends ShuffledRDD[K, V, C](parent, Some(aggregator), part) { + + override def compute(split: Split): Iterator[(K, C)] = { + val combiners = new JHashMap[K, C] + val fetcher = SparkEnv.get.shuffleFetcher + + if (aggregator.mapSideCombine) { + // Apply combiners on map partitions. In this case, post-shuffle we get a + // list of outputs from the combiners and merge them using mergeCombiners. + def mergePairWithMapSideCombiners(k: K, c: C) { + val oldC = combiners.get(k) + if (oldC == null) { + combiners.put(k, c) + } else { + combiners.put(k, aggregator.mergeCombiners(oldC, c)) + } + } + fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners) + } else { + // Do not apply combiners on map partitions (i.e. map side aggregation is + // turned off). Post-shuffle we get a list of values and we use mergeValue + // to merge them. + def mergePairWithoutMapSideCombiners(k: K, v: V) { + val oldC = combiners.get(k) + if (oldC == null) { + combiners.put(k, aggregator.createCombiner(v)) + } else { + combiners.put(k, aggregator.mergeValue(oldC, v)) + } + } + fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners) + } + + return new Iterator[(K, C)] { + var iter = combiners.entrySet().iterator() + + def hasNext: Boolean = iter.hasNext() + + def next(): (K, C) = { + val entry = iter.next() + (entry.getKey, entry.getValue) + } + } + } +} diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 0e8164d6ab..f0b9225f7c 100644 --- a/core/src/main/scala/spark/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -1,8 +1,14 @@ -package spark +package spark.rdd import scala.collection.mutable.ArrayBuffer -class UnionSplit[T: ClassManifest]( +import spark.Dependency +import spark.RangeDependency +import spark.RDD +import spark.SparkContext +import spark.Split + +private[spark] class UnionSplit[T: ClassManifest]( idx: Int, rdd: RDD[T], split: Split) @@ -37,7 +43,7 @@ class UnionRDD[T: ClassManifest]( override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 - for ((rdd, index) <- rdds.zipWithIndex) { + for (rdd <- rdds) { deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) pos += rdd.splits.size } diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala index 0ecff9ce77..5a4e9a582d 100644 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala @@ -5,11 +5,12 @@ import spark.TaskContext /** * Tracks information about an active job in the DAGScheduler. */ -class ActiveJob( +private[spark] class ActiveJob( val runId: Int, val finalStage: Stage, val func: (TaskContext, Iterator[_]) => _, val partitions: Array[Int], + val callSite: String, val listener: JobListener) { val numPartitions = partitions.length diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f7472971b5..6f4c6bffd7 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -21,6 +21,7 @@ import spark.storage.BlockManagerId * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). */ +private[spark] class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { taskSched.setListener(this) @@ -38,6 +39,11 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with eventQueue.put(HostLost(host)) } + // Called by TaskScheduler to cancel an entier TaskSet due to repeated failures. + override def taskSetFailed(taskSet: TaskSet, reason: String) { + eventQueue.put(TaskSetFailed(taskSet, reason)) + } + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in @@ -116,7 +122,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of splits is unknown - logInfo("Registering RDD " + rdd.id + ": " + rdd) + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") cacheTracker.registerRDD(rdd.id, rdd.splits.size) if (shuffleDep != None) { mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) @@ -139,7 +145,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with visited += r // Kind of ugly: need to register RDDs with the cache here since // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + ": " + r) + logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")") cacheTracker.registerRDD(r.id, r.splits.size) for (dep <- r.dependencies) { dep match { @@ -183,23 +189,25 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with missing.toList } - def runJob[T, U]( + def runJob[T, U: ClassManifest]( finalRdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], + callSite: String, allowLocal: Boolean) - (implicit m: ClassManifest[U]): Array[U] = + : Array[U] = { if (partitions.size == 0) { return new Array[U](0) } val waiter = new JobWaiter(partitions.size) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter)) + eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)) waiter.getResult() match { case JobSucceeded(results: Seq[_]) => return results.asInstanceOf[Seq[U]].toArray case JobFailed(exception: Exception) => + logInfo("Failed to run " + callSite) throw exception } } @@ -208,13 +216,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = + callSite: String, + timeout: Long) + : PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.splits.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener)) + eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) return listener.getResult() // Will throw an exception if the job fails } @@ -234,13 +243,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) => + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => val runId = nextRunId.getAndIncrement() val finalStage = newStage(finalRDD, None, runId) - val job = new ActiveJob(runId, finalStage, func, partitions, listener) + val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) updateCacheLocs() - logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions") - logInfo("Final stage: " + finalStage) + logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + + " output partitions") + logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { @@ -258,6 +268,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case completion: CompletionEvent => handleTaskCompletion(completion) + case TaskSetFailed(taskSet, reason) => + abortStage(idToStage(taskSet.stageId), reason) + case StopDAGScheduler => // Cancel any active jobs for (job <- activeJobs) { @@ -329,7 +342,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with val missing = getMissingParentStages(stage).sortBy(_.id) logDebug("missing: " + missing) if (missing == Nil) { - logInfo("Submitting " + stage + ", which has no missing parents") + logInfo("Submitting " + stage + " (" + stage.origin + "), which has no missing parents") submitMissingTasks(stage) running += stage } else { @@ -409,14 +422,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case smt: ShuffleMapTask => val stage = idToStage(smt.stageId) - val bmAddress = event.result.asInstanceOf[BlockManagerId] - val host = bmAddress.ip + val status = event.result.asInstanceOf[MapStatus] + val host = status.address.ip logInfo("ShuffleMapTask finished with host " + host) if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos - stage.addOutputLoc(smt.partition, bmAddress) + stage.addOutputLoc(smt.partition, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { - logInfo(stage + " finished; looking for newly runnable stages") + logInfo(stage + " (" + stage.origin + ") finished; looking for newly runnable stages") running -= stage logInfo("running: " + running) logInfo("waiting: " + waiting) @@ -430,7 +443,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with if (stage.outputLocs.count(_ == Nil) != 0) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " because some of its tasks had failed: " + + logInfo("Resubmitting " + stage + " (" + stage.origin + + ") because some of its tasks had failed: " + stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) submitStage(stage) } else { @@ -444,6 +458,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with waiting --= newlyRunnable running ++= newlyRunnable for (stage <- newlyRunnable.sortBy(_.id)) { + logInfo("Submitting " + stage + " (" + stage.origin + "), which is now runnable") submitMissingTasks(stage) } } @@ -460,12 +475,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with running -= failedStage failed += failedStage // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " for resubmision due to a fetch failure") + logInfo("Marking " + failedStage + " (" + failedStage.origin + + ") for resubmision due to a fetch failure") // Mark the map whose fetch failed as broken in the map stage val mapStage = shuffleToMapStage(shuffleId) mapStage.removeOutputLoc(mapId, bmAddress) mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission") + logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin + + "); marking it for resubmission") failed += mapStage // Remember that a fetch failed now; this is used to resubmit the broken // stages later, after a small wait (to give other tasks the chance to fail) @@ -475,18 +492,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with handleHostLost(bmAddress.ip) } - case _ => - // Non-fetch failure -- probably a bug in the job, so bail out - // TODO: Cancel all tasks that are still running - resultStageToJob.get(stage) match { - case Some(job) => - val error = new SparkException("Task failed: " + task + ", reason: " + event.reason) - job.listener.jobFailed(error) - activeJobs -= job - resultStageToJob -= stage - case None => - logInfo("Ignoring result from " + task + " because its job has finished") - } + case other => + // Non-fetch failure -- probably a bug in user code; abort all jobs depending on this stage + abortStage(idToStage(task.stageId), task + " failed: " + other) } } @@ -509,6 +517,53 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with updateCacheLocs() } } + + /** + * Aborts all jobs depending on a particular Stage. This is called in response to a task set + * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + */ + def abortStage(failedStage: Stage, reason: String) { + val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq + for (resultStage <- dependentStages) { + val job = resultStageToJob(resultStage) + job.listener.jobFailed(new SparkException("Job failed: " + reason)) + activeJobs -= job + resultStageToJob -= resultStage + } + if (dependentStages.isEmpty) { + logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") + } + } + + /** + * Return true if one of stage's ancestors is target. + */ + def stageDependsOn(stage: Stage, target: Stage): Boolean = { + if (stage == target) { + return true + } + val visitedRdds = new HashSet[RDD[_]] + val visitedStages = new HashSet[Stage] + def visit(rdd: RDD[_]) { + if (!visitedRdds(rdd)) { + visitedRdds += rdd + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.priority) + if (!mapStage.isAvailable) { + visitedStages += mapStage + visit(mapStage.rdd) + } // Otherwise there's no need to follow the dependency back + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) + } + } + } + } + visit(stage.rdd) + visitedRdds.contains(target.rdd) + } def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { // If the partition is cached, return the cache locations diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 0fc73059c3..3422a21d9d 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -10,23 +10,26 @@ import spark._ * submitted) but there is a single "logic" thread that reads these events and takes decisions. * This greatly simplifies synchronization. */ -sealed trait DAGSchedulerEvent +private[spark] sealed trait DAGSchedulerEvent -case class JobSubmitted( +private[spark] case class JobSubmitted( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], allowLocal: Boolean, + callSite: String, listener: JobListener) extends DAGSchedulerEvent -case class CompletionEvent( +private[spark] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) extends DAGSchedulerEvent -case class HostLost(host: String) extends DAGSchedulerEvent +private[spark] case class HostLost(host: String) extends DAGSchedulerEvent -case object StopDAGScheduler extends DAGSchedulerEvent +private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent + +private[spark] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala index d4dd536a7d..f46b9d551d 100644 --- a/core/src/main/scala/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/spark/scheduler/JobListener.scala @@ -5,7 +5,7 @@ package spark.scheduler * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole * job fails (and no further taskSucceeded events will happen). */ -trait JobListener { +private[spark] trait JobListener { def taskSucceeded(index: Int, result: Any) def jobFailed(exception: Exception) } diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala index 62b458eccb..c4a74e526f 100644 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/spark/scheduler/JobResult.scala @@ -3,7 +3,7 @@ package spark.scheduler /** * A result of a job in the DAGScheduler. */ -sealed trait JobResult +private[spark] sealed trait JobResult -case class JobSucceeded(results: Seq[_]) extends JobResult -case class JobFailed(exception: Exception) extends JobResult +private[spark] case class JobSucceeded(results: Seq[_]) extends JobResult +private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala index 4c2ae23051..b3d4feebe5 100644 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala @@ -5,7 +5,7 @@ import scala.collection.mutable.ArrayBuffer /** * An object that waits for a DAGScheduler job to complete. */ -class JobWaiter(totalTasks: Int) extends JobListener { +private[spark] class JobWaiter(totalTasks: Int) extends JobListener { private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) private var finishedTasks = 0 diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala new file mode 100644 index 0000000000..4532d9497f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/MapStatus.scala @@ -0,0 +1,27 @@ +package spark.scheduler + +import spark.storage.BlockManagerId +import java.io.{ObjectOutput, ObjectInput, Externalizable} + +/** + * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the + * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. + * The map output sizes are compressed using MapOutputTracker.compressSize. + */ +private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes: Array[Byte]) + extends Externalizable { + + def this() = this(null, null) // For deserialization only + + def writeExternal(out: ObjectOutput) { + address.writeExternal(out) + out.writeInt(compressedSizes.length) + out.write(compressedSizes) + } + + def readExternal(in: ObjectInput) { + address = new BlockManagerId(in) + compressedSizes = new Array[Byte](in.readInt()) + in.readFully(compressedSizes) + } +} diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 090ced9d76..2ebd4075a2 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -2,7 +2,7 @@ package spark.scheduler import spark._ -class ResultTask[T, U]( +private[spark] class ResultTask[T, U]( stageId: Int, rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index f78e0e5fb2..86796d3677 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -1,10 +1,10 @@ package spark.scheduler import java.io._ -import java.util.HashMap +import java.util.{HashMap => JHashMap} import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConversions._ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -15,9 +15,12 @@ import com.ning.compress.lzf.LZFOutputStream import spark._ import spark.storage._ -object ShuffleMapTask { - val serializedInfoCache = new HashMap[Int, Array[Byte]] - val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])] +private[spark] object ShuffleMapTask { + + // A simple map between the stage id to the serialized byte array of a task. + // Served as a cache for task serialization because serialization can be + // expensive on the master node if it needs to launch thousands of tasks. + val serializedInfoCache = new JHashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = { synchronized { @@ -26,7 +29,8 @@ object ShuffleMapTask { return old } else { val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(dep) objOut.close() @@ -39,40 +43,38 @@ object ShuffleMapTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { synchronized { - val old = deserializedInfoCache.get(stageId) - if (old != null) { - return old - } else { - val loader = Thread.currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] - val tuple = (rdd, dep) - deserializedInfoCache.put(stageId, tuple) - return tuple - } + val loader = Thread.currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val ser = SparkEnv.get.closureSerializer.newInstance + val objIn = ser.deserializeStream(in) + val rdd = objIn.readObject().asInstanceOf[RDD[_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] + return (rdd, dep) } } + // Since both the JarSet and FileSet have the same format this is used for both. + def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = { + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val objIn = new ObjectInputStream(in) + val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap + return (HashMap(set.toSeq: _*)) + } + def clearCache() { synchronized { serializedInfoCache.clear() - deserializedInfoCache.clear() } } } -class ShuffleMapTask( +private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_,_], var partition: Int, @transient var locs: Seq[String]) - extends Task[BlockManagerId](stageId) + extends Task[MapStatus](stageId) with Externalizable with Logging { @@ -90,6 +92,7 @@ class ShuffleMapTask( out.writeInt(bytes.length) out.write(bytes) out.writeInt(partition) + out.writeLong(generation) out.writeObject(split) } @@ -102,35 +105,54 @@ class ShuffleMapTask( rdd = rdd_ dep = dep_ partition = in.readInt() + generation = in.readLong() split = in.readObject().asInstanceOf[Split] } - override def run(attemptId: Long): BlockManagerId = { + override def run(attemptId: Long): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]] val partitioner = dep.partitioner - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any]) - for (elem <- rdd.iterator(split)) { - val (k, v) = elem.asInstanceOf[(Any, Any)] - var bucketId = partitioner.getPartition(k) - val bucket = buckets(bucketId) - var existing = bucket.get(k) - if (existing == null) { - bucket.put(k, aggregator.createCombiner(v)) + + val bucketIterators = + if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) { + val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]] + // Apply combiners (map-side aggregation) to the map output. + val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any]) + for (elem <- rdd.iterator(split)) { + val (k, v) = elem.asInstanceOf[(Any, Any)] + val bucketId = partitioner.getPartition(k) + val bucket = buckets(bucketId) + val existing = bucket.get(k) + if (existing == null) { + bucket.put(k, aggregator.createCombiner(v)) + } else { + bucket.put(k, aggregator.mergeValue(existing, v)) + } + } + buckets.map(_.iterator) } else { - bucket.put(k, aggregator.mergeValue(existing, v)) + // No combiners (no map-side aggregation). Simply partition the map output. + val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)]) + for (elem <- rdd.iterator(split)) { + val pair = elem.asInstanceOf[(Any, Any)] + val bucketId = partitioner.getPartition(pair._1) + buckets(bucketId) += pair + } + buckets.map(_.iterator) } - } - val ser = SparkEnv.get.serializer.newInstance() + + val compressedSizes = new Array[Byte](numOutputSplits) + val blockManager = SparkEnv.get.blockManager for (i <- 0 until numOutputSplits) { - val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i - // Get a scala iterator from java map - val iter: Iterator[(Any, Any)] = buckets(i).iterator - // TODO: This should probably be DISK_ONLY - blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false) + val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i + // Get a Scala iterator from Java map + val iter: Iterator[(Any, Any)] = bucketIterators(i) + val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false) + compressedSizes(i) = MapOutputTracker.compressSize(size) } - return SparkEnv.get.blockManager.blockManagerId + + return new MapStatus(blockManager.blockManagerId, compressedSizes) } override def preferredLocations: Seq[String] = locs diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index cd660c9085..1149c00a23 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -19,7 +19,7 @@ import spark.storage.BlockManagerId * Each Stage also has a priority, which is (by default) based on the job it was submitted in. * This allows Stages from earlier jobs to be computed first or recovered faster on failure. */ -class Stage( +private[spark] class Stage( val id: Int, val rdd: RDD[_], val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage @@ -29,29 +29,29 @@ class Stage( val isShuffleMap = shuffleDep != None val numPartitions = rdd.splits.size - val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil) + val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 private var nextAttemptId = 0 def isAvailable: Boolean = { - if (/*parents.size == 0 &&*/ !isShuffleMap) { + if (!isShuffleMap) { true } else { numAvailableOutputs == numPartitions } } - def addOutputLoc(partition: Int, bmAddress: BlockManagerId) { + def addOutputLoc(partition: Int, status: MapStatus) { val prevList = outputLocs(partition) - outputLocs(partition) = bmAddress :: prevList + outputLocs(partition) = status :: prevList if (prevList == Nil) numAvailableOutputs += 1 } def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_ == bmAddress) + val newList = prevList.filterNot(_.address == bmAddress) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { numAvailableOutputs -= 1 @@ -62,7 +62,7 @@ class Stage( var becameUnavailable = false for (partition <- 0 until numPartitions) { val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.ip == host) + val newList = prevList.filterNot(_.address.ip == host) outputLocs(partition) = newList if (prevList != Nil && newList == Nil) { becameUnavailable = true @@ -80,6 +80,8 @@ class Stage( return id } + def origin: String = rdd.origin + override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" override def hashCode(): Int = id diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala index f84d8d9c4f..ef987fdeb6 100644 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ b/core/src/main/scala/spark/scheduler/Task.scala @@ -1,11 +1,95 @@ package spark.scheduler +import scala.collection.mutable.HashMap +import spark.serializer.{SerializerInstance, Serializer} +import java.io.{DataInputStream, DataOutputStream} +import java.nio.ByteBuffer +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import spark.util.ByteBufferInputStream +import scala.collection.mutable.HashMap + /** * A task to execute on a worker node. */ -abstract class Task[T](val stageId: Int) extends Serializable { +private[spark] abstract class Task[T](val stageId: Int) extends Serializable { def run(attemptId: Long): T def preferredLocations: Seq[String] = Nil var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. } + +/** + * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We + * need to send the list of JARs and files added to the SparkContext with each task to ensure that + * worker nodes find out about it, but we can't make it part of the Task because the user's code in + * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by + * first writing out its dependencies. + */ +private[spark] object Task { + /** + * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) + */ + def serializeWithDependencies( + task: Task[_], + currentFiles: HashMap[String, Long], + currentJars: HashMap[String, Long], + serializer: SerializerInstance) + : ByteBuffer = { + + val out = new FastByteArrayOutputStream(4096) + val dataOut = new DataOutputStream(out) + + // Write currentFiles + dataOut.writeInt(currentFiles.size) + for ((name, timestamp) <- currentFiles) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) + } + + // Write currentJars + dataOut.writeInt(currentJars.size) + for ((name, timestamp) <- currentJars) { + dataOut.writeUTF(name) + dataOut.writeLong(timestamp) + } + + // Write the task itself and finish + dataOut.flush() + val taskBytes = serializer.serialize(task).array() + out.write(taskBytes) + out.trim() + ByteBuffer.wrap(out.array) + } + + /** + * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, + * and return the task itself as a serialized ByteBuffer. The caller can then update its + * ClassLoaders and deserialize the task. + * + * @return (taskFiles, taskJars, taskBytes) + */ + def deserializeWithDependencies(serializedTask: ByteBuffer) + : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + + val in = new ByteBufferInputStream(serializedTask) + val dataIn = new DataInputStream(in) + + // Read task's files + val taskFiles = new HashMap[String, Long]() + val numFiles = dataIn.readInt() + for (i <- 0 until numFiles) { + taskFiles(dataIn.readUTF()) = dataIn.readLong() + } + + // Read task's JARs + val taskJars = new HashMap[String, Long]() + val numJars = dataIn.readInt() + for (i <- 0 until numJars) { + taskJars(dataIn.readUTF()) = dataIn.readLong() + } + + // Create a sub-buffer for the rest of the data, which is the serialized Task object + val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task + (taskFiles, taskJars, subBuffer) + } +} diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala index 868ddb237c..9a54d0e854 100644 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/spark/scheduler/TaskResult.scala @@ -7,6 +7,7 @@ import scala.collection.mutable.Map // Task result. Also contains updates to accumulator variables. // TODO: Use of distributed cache to return result is a hack to get around // what seems to be a bug with messages over 60KB in libprocess; fix it +private[spark] class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable { def this() = this(null.asInstanceOf[T], null) diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index c35633d53c..d549b184b0 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -7,7 +7,7 @@ package spark.scheduler * are failures, and mitigating stragglers. They return events to the DAGScheduler through * the TaskSchedulerListener interface. */ -trait TaskScheduler { +private[spark] trait TaskScheduler { def start(): Unit // Disconnect from the cluster. diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala index a647eec9e4..fa4de15d0d 100644 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -7,10 +7,13 @@ import spark.TaskEndReason /** * Interface for getting events back from the TaskScheduler. */ -trait TaskSchedulerListener { +private[spark] trait TaskSchedulerListener { // A task has finished or failed. def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit // A node was lost from the cluster. def hostLost(host: String): Unit + + // The TaskScheduler wants to abort an entire task set. + def taskSetFailed(taskSet: TaskSet, reason: String): Unit } diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index 6f29dd2e9d..a3002ca477 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -4,6 +4,8 @@ package spark.scheduler * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. */ -class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { +private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { val id: String = stageId + "." + attempt + + override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 5b59479682..f5e852d203 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -16,7 +16,7 @@ import java.util.concurrent.atomic.AtomicLong * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call * start(), then submit task sets through the runTasks method. */ -class ClusterScheduler(sc: SparkContext) +private[spark] class ClusterScheduler(val sc: SparkContext) extends TaskScheduler with Logging { @@ -60,7 +60,6 @@ class ClusterScheduler(sc: SparkContext) def initialize(context: SchedulerBackend) { backend = context - createJarServer() } def newTaskId(): Long = nextTaskId.getAndIncrement() @@ -115,6 +114,7 @@ class ClusterScheduler(sc: SparkContext) */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { synchronized { + SparkEnv.set(sc.env) // Mark each slave as alive and remember its hostname for (o <- offers) { slaveIdToHost(o.slaveId) = o.hostname @@ -235,32 +235,7 @@ class ClusterScheduler(sc: SparkContext) } override def defaultParallelism() = backend.defaultParallelism() - - // Create a server for all the JARs added by the user to SparkContext. - // We first copy the JARs to a temp directory for easier server setup. - private def createJarServer() { - val jarDir = Utils.createTempDir() - logInfo("Temp directory for JARs: " + jarDir) - val filenames = ArrayBuffer[String]() - // Copy each JAR to a unique filename in the jarDir - for ((path, index) <- sc.jars.zipWithIndex) { - val file = new File(path) - if (file.exists) { - val filename = index + "_" + file.getName - Utils.copyFile(file, new File(jarDir, filename)) - filenames += filename - } - } - // Create the server - jarServer = new HttpServer(jarDir) - jarServer.start() - // Build up the jar URI list - val serverUri = jarServer.uri - jarUris = filenames.map(f => serverUri + "/" + f).mkString(",") - System.setProperty("spark.jar.uris", jarUris) - logInfo("JAR server started at " + serverUri) - } - + // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index 897976c3f9..ddcd64d7c6 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -5,7 +5,7 @@ package spark.scheduler.cluster * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as * machines become available and can launch tasks on them. */ -trait SchedulerBackend { +private[spark] trait SchedulerBackend { def start(): Unit def stop(): Unit def reviveOffers(): Unit diff --git a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala index e15d577a8b..96ebaa4601 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SlaveResources.scala @@ -1,3 +1,4 @@ package spark.scheduler.cluster +private[spark] class SlaveResources(val slaveId: String, val hostname: String, val coresFree: Int) {} diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 0bd2d15479..7aba7324ab 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -5,7 +5,7 @@ import spark.deploy.client.{Client, ClientListener} import spark.deploy.{Command, JobDescription} import scala.collection.mutable.HashMap -class SparkDeploySchedulerBackend( +private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, @@ -16,17 +16,10 @@ class SparkDeploySchedulerBackend( var client: Client = null var stopping = false + var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt - // Environment variables to pass to our executors - val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( - "SPARK_MEM", - "SPARK_CLASSPATH", - "SPARK_LIBRARY_PATH", - "SPARK_JAVA_OPTS" - ) - // Memory used by each executor (in megabytes) val executorMemory = { if (System.getenv("SPARK_MEM") != null) { @@ -40,17 +33,11 @@ class SparkDeploySchedulerBackend( override def start() { super.start() - val environment = new HashMap[String, String] - for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) { - if (System.getenv(key) != null) { - environment(key) = System.getenv(key) - } - } val masterUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.master.host"), System.getProperty("spark.master.port"), StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}") - val command = Command("spark.executor.StandaloneExecutorBackend", args, environment) + val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command) client = new Client(sc.env.actorSystem, master, jobDesc, this) @@ -61,6 +48,9 @@ class SparkDeploySchedulerBackend( stopping = true; super.stop() client.stop() + if (shutdownCallback != null) { + shutdownCallback(this) + } } def connected(jobId: String) { diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 80e8733671..1386cd9d44 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -4,19 +4,27 @@ import spark.TaskState.TaskState import java.nio.ByteBuffer import spark.util.SerializableBuffer -sealed trait StandaloneClusterMessage extends Serializable +private[spark] sealed trait StandaloneClusterMessage extends Serializable // Master to slaves +private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + +private[spark] case class RegisteredSlave(sparkProperties: Seq[(String, String)]) extends StandaloneClusterMessage + +private[spark] case class RegisterSlaveFailed(message: String) extends StandaloneClusterMessage // Slaves to master +private[spark] case class RegisterSlave(slaveId: String, host: String, cores: Int) extends StandaloneClusterMessage +private[spark] case class StatusUpdate(slaveId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends StandaloneClusterMessage +private[spark] object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ def apply(slaveId: String, taskId: Long, state: TaskState, data: ByteBuffer): StatusUpdate = { @@ -25,5 +33,5 @@ object StatusUpdate { } // Internal messages in master -case object ReviveOffers extends StandaloneClusterMessage -case object StopMaster extends StandaloneClusterMessage +private[spark] case object ReviveOffers extends StandaloneClusterMessage +private[spark] case object StopMaster extends StandaloneClusterMessage diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 013671c1c8..d2cce0dc05 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -2,19 +2,21 @@ package spark.scheduler.cluster import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import akka.actor.{Props, Actor, ActorRef, ActorSystem} +import akka.actor._ import akka.util.duration._ import akka.pattern.ask import spark.{SparkException, Logging, TaskState} import akka.dispatch.Await import java.util.concurrent.atomic.AtomicInteger +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} /** * A standalone scheduler backend, which waits for standalone executors to connect to it through * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). */ +private[spark] class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { @@ -23,8 +25,16 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor class MasterActor(sparkProperties: Seq[(String, String)]) extends Actor { val slaveActor = new HashMap[String, ActorRef] + val slaveAddress = new HashMap[String, Address] val slaveHost = new HashMap[String, String] val freeCores = new HashMap[String, Int] + val actorToSlaveId = new HashMap[ActorRef, String] + val addressToSlaveId = new HashMap[Address, String] + + override def preStart() { + // Listen for remote client disconnection events, since they don't go through Akka's watch() + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + } def receive = { case RegisterSlave(slaveId, host, cores) => @@ -33,9 +43,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } else { logInfo("Registered slave: " + sender + " with ID " + slaveId) sender ! RegisteredSlave(sparkProperties) + context.watch(sender) slaveActor(slaveId) = sender slaveHost(slaveId) = host freeCores(slaveId) = cores + slaveAddress(slaveId) = sender.path.address + actorToSlaveId(sender) = slaveId + addressToSlaveId(sender.path.address) = slaveId totalCoreCount.addAndGet(cores) makeOffers() } @@ -54,7 +68,14 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) - // TODO: Deal with nodes disconnecting too! (Including decreasing totalCoreCount) + case Terminated(actor) => + actorToSlaveId.get(actor).foreach(removeSlave) + + case RemoteClientDisconnected(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) + + case RemoteClientShutdown(transport, address) => + addressToSlaveId.get(address).foreach(removeSlave) } // Make fake resource offers on all slaves @@ -76,6 +97,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor slaveActor(task.slaveId) ! LaunchTask(task) } } + + // Remove a disconnected slave from the cluster + def removeSlave(slaveId: String) { + logInfo("Slave " + slaveId + " disconnected, so removing it") + val numCores = freeCores(slaveId) + actorToSlaveId -= slaveActor(slaveId) + addressToSlaveId -= slaveAddress(slaveId) + slaveActor -= slaveId + slaveHost -= slaveId + freeCores -= slaveId + slaveHost -= slaveId + totalCoreCount.addAndGet(-numCores) + scheduler.slaveLost(slaveId) + } } var masterActor: ActorRef = null @@ -115,6 +150,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) } -object StandaloneSchedulerBackend { +private[spark] object StandaloneSchedulerBackend { val ACTOR_NAME = "StandaloneScheduler" } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index f9a1b74fa5..aa097fd3a2 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -3,7 +3,7 @@ package spark.scheduler.cluster import java.nio.ByteBuffer import spark.util.SerializableBuffer -class TaskDescription( +private[spark] class TaskDescription( val taskId: Long, val slaveId: String, val name: String, diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala index 0fc1d8ed30..ca84503780 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala @@ -3,6 +3,7 @@ package spark.scheduler.cluster /** * Information about a running task attempt inside a TaskSet. */ +private[spark] class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: String) { var finishTime: Long = 0 var failed = false @@ -20,6 +21,8 @@ class TaskInfo(val taskId: Long, val index: Int, val launchTime: Long, val host: def successful: Boolean = finished && !failed + def running: Boolean = !finished + def duration: Long = { if (!finished) { throw new UnsupportedOperationException("duration() called on unfinished tasks") diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index be24316e80..cf4aae03a7 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -17,7 +17,7 @@ import java.nio.ByteBuffer /** * Schedules the tasks within a single TaskSet in the ClusterScheduler. */ -class TaskSetManager( +private[spark] class TaskSetManager( sched: ClusterScheduler, val taskSet: TaskSet) extends Logging { @@ -88,6 +88,7 @@ class TaskSetManager( // Figure out the current map output tracker generation and set it on all tasks val generation = sched.mapOutputTracker.getGeneration + logDebug("Generation for " + taskSet.id + ": " + generation) for (t <- tasks) { t.generation = generation } @@ -213,7 +214,8 @@ class TaskSetManager( } // Serialize and return the task val startTime = System.currentTimeMillis - val serializedTask = ser.serialize(task) + val serializedTask = Task.serializeWithDependencies( + task, sched.sc.addedFiles, sched.sc.addedJars, ser) val timeTaken = System.currentTimeMillis - startTime logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) @@ -242,6 +244,11 @@ class TaskSetManager( def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } val index = info.index info.markSuccessful() if (!finished(index)) { @@ -264,6 +271,11 @@ class TaskSetManager( def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } val index = info.index info.markFailed() if (!finished(index)) { @@ -329,18 +341,19 @@ class TaskSetManager( def error(message: String) { // Save the error message - abort("Mesos error: " + message) + abort("Error: " + message) } def abort(message: String) { failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error + sched.listener.taskSetFailed(taskSet, message) sched.taskSetFinished(this) } def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) + logInfo("Re-queueing tasks for " + hostname + " from TaskSet " + taskSet.id) // If some task has preferred locations only on hostname, put it in the no-prefs list // to avoid the wait from delay scheduling for (index <- getPendingTasksForHost(hostname)) { @@ -349,7 +362,7 @@ class TaskSetManager( pendingTasksWithNoPrefs += index } } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + // Re-enqueue any tasks that ran on the failed host if this is a shuffle map stage if (tasks(0).isInstanceOf[ShuffleMapTask]) { for ((tid, info) <- taskInfos if info.host == hostname) { val index = taskInfos(tid).index @@ -364,6 +377,10 @@ class TaskSetManager( } } } + // Also re-enqueue any tasks that were running on the node + for ((tid, info) <- taskInfos if info.running && info.host == hostname) { + taskLost(tid, TaskState.KILLED, null) + } } /** diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala index 1e83f103e7..6b919d68b2 100644 --- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala +++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala @@ -3,5 +3,6 @@ package spark.scheduler.cluster /** * Represents free resources available on a worker node. */ +private[spark] class WorkerOffer(val slaveId: String, val hostname: String, val cores: Int) { } diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index eb47988f0c..b84b4dc2ed 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -1,9 +1,13 @@ package spark.scheduler.local +import java.io.File +import java.net.URLClassLoader import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.HashMap import spark._ +import executor.ExecutorURLClassLoader import spark.scheduler._ /** @@ -11,15 +15,25 @@ import spark.scheduler._ * the scheduler also allows each task to fail up to maxFailures times, which is useful for * testing fault recovery. */ -class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging { +private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext) + extends TaskScheduler + with Logging { + var attemptId = new AtomicInteger(0) var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) val env = SparkEnv.get var listener: TaskSchedulerListener = null + // Application dependencies (added through SparkContext) that we've fetched so far on this node. + // Each map holds the master's timestamp for the version of that file or JAR we got. + val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() + val currentJars: HashMap[String, Long] = new HashMap[String, Long]() + + val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) + // TODO: Need to take into account stage priority in scheduling - override def start() {} + override def start() { } override def setListener(listener: TaskSchedulerListener) { this.listener = listener @@ -43,15 +57,22 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with // Set the Spark execution environment for the worker thread SparkEnv.set(env) try { + Accumulators.clear() + Thread.currentThread().setContextClassLoader(classLoader) + // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - Accumulators.clear val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = ser.serialize(task) + val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser) logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") + val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) + updateDependencies(taskFiles, taskJars) // Download any files added with addFile val deserializedTask = ser.deserialize[Task[_]]( - bytes, Thread.currentThread.getContextClassLoader) + taskBytes, Thread.currentThread.getContextClassLoader) + + // Run it val result: Any = deserializedTask.run(attemptId) + // Serialize and deserialize the result to emulate what the Mesos // executor does. This is useful to catch serialization errors early // on in development (so when users move their local Spark programs @@ -80,6 +101,31 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with submitTask(task, i) } } + + /** + * Download any missing dependencies if we receive a new set of files and JARs from the + * SparkContext. Also adds any new JARs we fetched to the class loader. + */ + private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + // Fetch missing dependencies + for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentFiles(name) = timestamp + } + for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { + logInfo("Fetching " + name + " with timestamp " + timestamp) + Utils.fetchFile(name, new File(".")) + currentJars(name) = timestamp + // Add it to our class loader + val localName = name.split("/").last + val url = new File(".", localName).toURI.toURL + if (!classLoader.getURLs.contains(url)) { + logInfo("Adding " + url + " to class loader") + classLoader.addURL(url) + } + } + } override def stop() { threadPool.shutdownNow() diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index 31784985dc..c45c7df69c 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -24,7 +24,7 @@ import spark.TaskState * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to * remove this. */ -class CoarseMesosSchedulerBackend( +private[spark] class CoarseMesosSchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, @@ -33,14 +33,6 @@ class CoarseMesosSchedulerBackend( with MScheduler with Logging { - // Environment variables to pass to our executors - val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( - "SPARK_MEM", - "SPARK_CLASSPATH", - "SPARK_LIBRARY_PATH", - "SPARK_JAVA_OPTS" - ) - val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures // Memory used by each executor (in megabytes) @@ -80,6 +72,8 @@ class CoarseMesosSchedulerBackend( "property, the SPARK_HOME environment variable or the SparkContext constructor") } + val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt + var nextMesosTaskId = 0 def newMesosTaskId(): Int = { @@ -120,13 +114,11 @@ class CoarseMesosSchedulerBackend( val command = "\"%s\" spark.executor.StandaloneExecutorBackend %s %s %s %d".format( runScript, masterUrl, offer.getSlaveId.getValue, offer.getHostname, numCores) val environment = Environment.newBuilder() - for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) { - if (System.getenv(key) != null) { - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(System.getenv(key)) - .build()) - } + sc.executorEnvs.foreach { case (key, value) => + environment.addVariables(Environment.Variable.newBuilder() + .setName(key) + .setValue(value) + .build()) } return CommandInfo.newBuilder().setValue(command).setEnvironment(environment).build() } @@ -177,7 +169,7 @@ class CoarseMesosSchedulerBackend( val task = MesosTaskInfo.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse)) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) .addResources(createResource("mem", executorMemory)) diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 44eda93dd1..cdfe1f2563 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -20,7 +20,7 @@ import spark.TaskState * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks * from multiple apps can run on different cores) and in time (a core can switch ownership). */ -class MesosSchedulerBackend( +private[spark] class MesosSchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, @@ -29,14 +29,6 @@ class MesosSchedulerBackend( with MScheduler with Logging { - // Environment variables to pass to our executors - val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( - "SPARK_MEM", - "SPARK_CLASSPATH", - "SPARK_LIBRARY_PATH", - "SPARK_JAVA_OPTS" - ) - // Memory used by each executor (in megabytes) val EXECUTOR_MEMORY = { if (System.getenv("SPARK_MEM") != null) { @@ -93,13 +85,11 @@ class MesosSchedulerBackend( } val execScript = new File(sparkHome, "spark-executor").getCanonicalPath val environment = Environment.newBuilder() - for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) { - if (System.getenv(key) != null) { - environment.addVariables(Environment.Variable.newBuilder() - .setName(key) - .setValue(System.getenv(key)) - .build()) - } + sc.executorEnvs.foreach { case (key, value) => + environment.addVariables(Environment.Variable.newBuilder() + .setName(key) + .setValue(value) + .build()) } val memory = Resource.newBuilder() .setName("mem") diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index 61a70beaf1..50b086125a 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -1,23 +1,21 @@ -package spark +package spark.serializer -import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer -import java.nio.channels.Channels - +import java.io.{EOFException, InputStream, OutputStream} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - import spark.util.ByteBufferInputStream /** - * A serializer. Because some serialization libraries are not thread safe, this class is used to - * create SerializerInstances that do the actual serialization. + * A serializer. Because some serialization libraries are not thread safe, this class is used to + * create [[spark.serializer.SerializerInstance]] objects that do the actual serialization and are + * guaranteed to only be called from one thread at a time. */ trait Serializer { def newInstance(): SerializerInstance } /** - * An instance of the serializer, for use by one thread at a time. + * An instance of a serializer, for use by one thread at a time. */ trait SerializerInstance { def serialize[T](t: T): ByteBuffer @@ -43,7 +41,7 @@ trait SerializerInstance { def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { // Default implementation uses deserializeStream buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).toIterator + deserializeStream(new ByteBufferInputStream(buffer)).asIterator } } @@ -51,7 +49,7 @@ trait SerializerInstance { * A stream for writing serialized objects. */ trait SerializationStream { - def writeObject[T](t: T): Unit + def writeObject[T](t: T): SerializationStream def flush(): Unit def close(): Unit @@ -74,7 +72,7 @@ trait DeserializationStream { * Read the elements of this stream through an iterator. This can only be called once, as * reading each element will consume data from the input source. */ - def toIterator: Iterator[Any] = new Iterator[Any] { + def asIterator: Iterator[Any] = new Iterator[Any] { var gotNext = false var finished = false var nextValue: Any = null @@ -88,7 +86,7 @@ trait DeserializationStream { } gotNext = true } - + override def hasNext: Boolean = { if (!gotNext) { getNext() diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index ff9914ae25..bd9155ef29 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -1,34 +1,29 @@ package spark.storage -import java.io._ -import java.nio._ -import java.nio.channels.FileChannel.MapMode -import java.util.{HashMap => JHashMap} -import java.util.LinkedHashMap -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.LinkedBlockingQueue -import java.util.Collections - import akka.dispatch.{Await, Future} -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ +import akka.util.Duration -import it.unimi.dsi.fastutil.io._ +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.CacheTracker -import spark.Logging -import spark.Serializer -import spark.SizeEstimator -import spark.SparkEnv -import spark.SparkException -import spark.Utils -import spark.util.ByteBufferInputStream +import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput} +import java.nio.{MappedByteBuffer, ByteBuffer} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.JavaConversions._ + +import spark.{CacheTracker, Logging, SizeEstimator, SparkException, Utils} import spark.network._ -import akka.util.Duration +import spark.serializer.Serializer +import spark.util.ByteBufferInputStream +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import sun.nio.ch.DirectBuffer + -class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) +private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) // For deserialization only + + def this(in: ObjectInput) = this(in.readUTF(), in.readInt()) override def writeExternal(out: ObjectOutput) { out.writeUTF(ip) @@ -51,41 +46,76 @@ class BlockManagerId(var ip: String, var port: Int) extends Externalizable { } -case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) +private[spark] +case class BlockException(blockId: String, message: String, ex: Exception = null) +extends Exception(message) -class BlockLocker(numLockers: Int) { +private[spark] class BlockLocker(numLockers: Int) { private val hashLocker = Array.fill(numLockers)(new Object()) - + def getLock(blockId: String): Object = { return hashLocker(math.abs(blockId.hashCode % numLockers)) } } +private[spark] class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, maxMemory: Long) extends Logging { - case class BlockInfo(level: StorageLevel, tellMaster: Boolean) + class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + var pending: Boolean = true + var size: Long = -1L + + /** Wait for this BlockInfo to be marked as ready (i.e. block is finished writing) */ + def waitForReady() { + if (pending) { + synchronized { + while (pending) this.wait() + } + } + } + + /** Mark this BlockInfo as ready (i.e. block is finished writing) */ + def markReady(sizeInBytes: Long) { + pending = false + size = sizeInBytes + synchronized { + this.notifyAll() + } + } + } private val NUM_LOCKS = 337 private val locker = new BlockLocker(NUM_LOCKS) private val blockInfo = new ConcurrentHashMap[String, BlockInfo]() - private val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private val diskStore: BlockStore = new DiskStore(this, - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - + + private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val diskStore: BlockStore = + new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + val connectionManager = new ConnectionManager(0) implicit val futureExecContext = connectionManager.futureExecContext - + val connectionManagerId = connectionManager.id val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - + // TODO: This will be removed after cacheTracker is removed from the code base. var cacheTracker: CacheTracker = null - initLogging() + // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory + // for receiving shuffle outputs) + val maxBytesInFlight = + System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + + val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean + val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean + // Whether to compress RDD partitions that are stored serialized + val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean + + val host = System.getProperty("spark.hostname", Utils.localHostName()) initialize() @@ -102,7 +132,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m */ private def initialize() { master.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory, maxMemory)) + RegisterBlockManager(blockManagerId, maxMemory)) BlockManagerWorker.startBlockManagerWorker(this) } @@ -115,36 +145,32 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } /** - * Change storage level for a local block and tell master is necesary. - * If new level is invalid, then block info (if it exists) will be silently removed. + * Tell the master about the current storage status of a block. This will send a heartbeat + * message reflecting the current status, *not* the desired storage level in its block info. + * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk. */ - def setLevel(blockId: String, level: StorageLevel, tellMaster: Boolean = true) { - if (level == null) { - throw new IllegalArgumentException("Storage level is null") - } - - // If there was earlier info about the block, then use earlier tellMaster - val oldInfo = blockInfo.get(blockId) - val newTellMaster = if (oldInfo != null) oldInfo.tellMaster else tellMaster - if (oldInfo != null && oldInfo.tellMaster != tellMaster) { - logWarning("Ignoring tellMaster setting as it is different from earlier setting") - } - - // If level is valid, store the block info, else remove the block info - if (level.isValid) { - blockInfo.put(blockId, new BlockInfo(level, newTellMaster)) - logDebug("Info for block " + blockId + " updated with new level as " + level) - } else { - blockInfo.remove(blockId) - logDebug("Info for block " + blockId + " removed as new level is null or invalid") - } - - // Tell master if necessary - if (newTellMaster) { + def reportBlockStatus(blockId: String) { + locker.getLock(blockId).synchronized { + val curLevel = blockInfo.get(blockId) match { + case null => + StorageLevel.NONE + case info => + info.level match { + case null => + StorageLevel.NONE + case level => + val inMem = level.useMemory && memoryStore.contains(blockId) + val onDisk = level.useDisk && diskStore.contains(blockId) + new StorageLevel(onDisk, inMem, level.deserialized, level.replication) + } + } + master.mustHeartBeat(HeartBeat( + blockManagerId, + blockId, + curLevel, + if (curLevel.useMemory) memoryStore.getSize(blockId) else 0L, + if (curLevel.useDisk) diskStore.getSize(blockId) else 0L)) logDebug("Told master about block " + blockId) - notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0)) - } else { - logDebug("Did not tell master about block " + blockId) } } @@ -174,55 +200,149 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m * Get block from local block manager. */ def getLocal(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } logDebug("Getting local block " + blockId) + + // As an optimization for map output fetches, if the block is for a shuffle, return it + // without acquiring a lock; the disk store never deletes (recent) items so this should work + if (blockId.startsWith("shuffle_")) { + return diskStore.getValues(blockId) match { + case Some(iterator) => + Some(iterator) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + locker.getLock(blockId).synchronized { - - // Check storage level of block - val level = getLevel(blockId) - if (level != null) { - logDebug("Level for block " + blockId + " is " + level + " on local machine") - + val info = blockInfo.get(blockId) + if (info != null) { + info.waitForReady() // In case the block is still being put() by another thread + val level = info.level + logDebug("Level for block " + blockId + " is " + level) + // Look for the block in memory if (level.useMemory) { logDebug("Getting block " + blockId + " from memory") memoryStore.getValues(blockId) match { - case Some(iterator) => { - logDebug("Block " + blockId + " found in memory") + case Some(iterator) => return Some(iterator) - } - case None => { + case None => logDebug("Block " + blockId + " not found in memory") - } } - } else { - logDebug("Not getting block " + blockId + " from memory") } - // Look for block in disk + // Look for block on disk, potentially loading it back into memory if required if (level.useDisk) { logDebug("Getting block " + blockId + " from disk") - diskStore.getValues(blockId) match { - case Some(iterator) => { - logDebug("Block " + blockId + " found in disk") - return Some(iterator) + if (level.useMemory && level.deserialized) { + diskStore.getValues(blockId) match { + case Some(iterator) => + // Put the block back in memory before returning it + // TODO: Consider creating a putValues that also takes in a iterator ? + val elements = new ArrayBuffer[Any] + elements ++= iterator + memoryStore.putValues(blockId, elements, level, true).data match { + case Left(iterator2) => + return Some(iterator2) + case _ => + throw new Exception("Memory store did not return back an iterator") + } + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") } - case None => { - throw new Exception("Block " + blockId + " not found in disk") - return None + } else if (level.useMemory && !level.deserialized) { + // Read it as a byte buffer into memory first, then return it + diskStore.getBytes(blockId) match { + case Some(bytes) => + // Put a copy of the block back in memory before returning it. Note that we can't + // put the ByteBuffer returned by the disk store as that's a memory-mapped file. + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + bytes.rewind() + return Some(dataDeserialize(blockId, bytes)) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } else { + diskStore.getValues(blockId) match { + case Some(iterator) => + return Some(iterator) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") } } - } else { - logDebug("Not getting block " + blockId + " from disk") } + } else { + logDebug("Block " + blockId + " not registered locally") + } + } + return None + } + /** + * Get block from the local block manager as serialized bytes. + */ + def getLocalBytes(blockId: String): Option[ByteBuffer] = { + // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow + logDebug("Getting local block " + blockId + " as bytes") + + // As an optimization for map output fetches, if the block is for a shuffle, return it + // without acquiring a lock; the disk store never deletes (recent) items so this should work + if (blockId.startsWith("shuffle_")) { + return diskStore.getBytes(blockId) match { + case Some(bytes) => + Some(bytes) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } + + locker.getLock(blockId).synchronized { + val info = blockInfo.get(blockId) + if (info != null) { + info.waitForReady() // In case the block is still being put() by another thread + val level = info.level + logDebug("Level for block " + blockId + " is " + level) + + // Look for the block in memory + if (level.useMemory) { + logDebug("Getting block " + blockId + " from memory") + memoryStore.getBytes(blockId) match { + case Some(bytes) => + return Some(bytes) + case None => + logDebug("Block " + blockId + " not found in memory") + } + } + + // Look for block on disk + if (level.useDisk) { + // Read it as a byte buffer into memory first, then return it + diskStore.getBytes(blockId) match { + case Some(bytes) => + if (level.useMemory) { + if (level.deserialized) { + memoryStore.putBytes(blockId, bytes, level) + } else { + // The memory store will hang onto the ByteBuffer, so give it a copy instead of + // the memory-mapped file buffer we got from the disk store + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + } + } + bytes.rewind() + return Some(bytes) + case None => + throw new Exception("Block " + blockId + " not found on disk, though it should be") + } + } } else { - logDebug("Level for block " + blockId + " not found") + logDebug("Block " + blockId + " not registered locally") } - } - return None + } + return None } /** @@ -243,7 +363,7 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) if (data != null) { logDebug("Data is not null: " + data) - return Some(dataDeserialize(data)) + return Some(dataDeserialize(blockId, data)) } logDebug("Data is null") } @@ -261,9 +381,10 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m /** * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. + * fashion as they're received. Expects a size in bytes to be provided for each block fetched, + * so that we can control the maxMegabytesInFlight for the fetch. */ - def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[String])]) + def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]) : Iterator[(String, Option[Iterator[Any]])] = { if (blocksByAddress == null) { @@ -272,70 +393,128 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m val totalBlocks = blocksByAddress.map(_._2.size).sum logDebug("Getting " + totalBlocks + " blocks") var startTime = System.currentTimeMillis - val results = new LinkedBlockingQueue[(String, Option[Iterator[Any]])] val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new ArrayBuffer[String]() - val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + val remoteBlockIds = new HashSet[String]() - // Split local and remote blocks - for ((address, blockIds) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockIds - } else { - remoteBlockIds ++= blockIds - remoteBlockIdsPerLocation(address) = blockIds - } + // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize + // the block (since we want all deserializaton to happen in the calling thread); can also + // represent a fetch failure if size == -1. + class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 } - - // Start getting remote blocks - for ((bmId, bIds) <- remoteBlockIdsPerLocation) { - val cmId = ConnectionManagerId(bmId.ip, bmId.port) - val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId))) - val blockMessageArray = new BlockMessageArray(blockMessages) + + // A queue to hold our results. + val results = new LinkedBlockingQueue[FetchResult] + + // A request to fetch one or more blocks, complete with their sizes + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + val size = blocks.map(_._2).sum + } + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + var bytesInFlight = 0L + + def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip)) + val cmId = new ConnectionManagerId(req.address.ip, req.address.port) + val blockMessageArray = new BlockMessageArray(req.blocks.map { + case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) + }) + bytesInFlight += req.size + val sizeMap = req.blocks.toMap // so we can look up the size of each blockID val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) future.onSuccess { case Some(message) => { val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - blockMessageArray.foreach(blockMessage => { + for (blockMessage <- blockMessageArray) { if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { throw new SparkException( "Unexpected message " + blockMessage.getType + " received from " + cmId) } - val buffer = blockMessage.getData val blockId = blockMessage.getId - val block = dataDeserialize(buffer) - results.put((blockId, Some(block))) + results.put(new FetchResult( + blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData))) logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - }) + } } case None => { - logError("Could not get blocks from " + cmId) - for (blockId <- bIds) { - results.put((blockId, None)) + logError("Could not get block(s) from " + cmId) + for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } + } + } + } + + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + for ((address, blockInfos) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockInfos.map(_._1) + } else { + remoteBlockIds ++= blockInfos.map(_._1) + // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + val minRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(String, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + curBlocks += ((blockId, size)) + curRequestSize += size + if (curRequestSize >= minRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curRequestSize = 0 + curBlocks = new ArrayBuffer[(String, Long)] } } + // Add in the final request + if (!curBlocks.isEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } } } - logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + - Utils.getUsedTimeMs(startTime) + " ms") + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) - // Get the local blocks while remote blocks are being fetched + // Send out initial requests for blocks, up to our maxBytesInFlight + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numGets = remoteBlockIds.size - fetchRequests.size + logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime)) + + // Get the local blocks while remote blocks are being fetched. Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight startTime = System.currentTimeMillis - localBlockIds.foreach(id => { - get(id) match { - case Some(block) => { - results.put((id, Some(block))) + for (id <- localBlockIds) { + getLocal(id) match { + case Some(iter) => { + results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight logDebug("Got local block " + id) } case None => { throw new BlockException(id, "Could not get block " + id + " from local machine") } } - }) + } logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - // Return an iterator that will read fetched blocks off the queue as they arrive + // Return an iterator that will read fetched blocks off the queue as they arrive. return new Iterator[(String, Option[Iterator[Any]])] { var resultsGotten = 0 @@ -343,15 +522,30 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m def next(): (String, Option[Iterator[Any]]) = { resultsGotten += 1 - results.take() + val result = results.take() + bytesInFlight -= result.size + if (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) } } } + def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) + : Long = { + val elements = new ArrayBuffer[Any] + elements ++= values + put(blockId, elements, level, tellMaster) + } + /** - * Put a new block of values to the block manager. + * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) { + def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + tellMaster: Boolean = true) : Long = { + if (blockId == null) { throw new IllegalArgumentException("Block Id is null") } @@ -362,70 +556,97 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m throw new IllegalArgumentException("Storage level is null or invalid") } - val startTimeMs = System.currentTimeMillis - var bytes: ByteBuffer = null - + val oldBlock = blockInfo.get(blockId) + if (oldBlock != null) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + oldBlock.waitForReady() + return oldBlock.size + } + + // Remember the block's storage level so that we can correctly drop it to disk if it needs + // to be dropped right after it got put into memory. Note, however, that other threads will + // not be able to get() this block until we call markReady on its BlockInfo. + val myInfo = new BlockInfo(level, tellMaster) + blockInfo.put(blockId, myInfo) + + val startTimeMs = System.currentTimeMillis + + // If we need to replicate the data, we'll want access to the values, but because our + // put will read the whole iterator, there will be no values left. For the case where + // the put serializes data, we'll remember the bytes, above; but for the case where it + // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + var valuesAfterPut: Iterator[Any] = null + + // Ditto for the bytes after the put + var bytesAfterPut: ByteBuffer = null + + // Size of the block in bytes (to return to caller) + var size = 0L + locker.getLock(blockId).synchronized { logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") - - // Check and warn if block with same id already exists - if (getLevel(blockId) != null) { - logWarning("Block " + blockId + " already exists in local machine") - return - } - if (level.useMemory && level.useDisk) { - // If saving to both memory and disk, then serialize only once - memoryStore.putValues(blockId, values, level) match { - case Left(newValues) => - diskStore.putValues(blockId, newValues, level) match { - case Right(newBytes) => bytes = newBytes - case _ => throw new Exception("Unexpected return value") - } - case Right(newBytes) => - bytes = newBytes - diskStore.putBytes(blockId, newBytes, level) - } - } else if (level.useMemory) { - // If only save to memory - memoryStore.putValues(blockId, values, level) match { - case Right(newBytes) => bytes = newBytes - case _ => + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will later + // drop it to disk if the memory store can't hold it. + val res = memoryStore.putValues(blockId, values, level, true) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator } } else { - // If only save to disk - diskStore.putValues(blockId, values, level) match { - case Right(newBytes) => bytes = newBytes - case _ => throw new Exception("Unexpected return value") + // Save directly to disk. + val askForBytes = level.replication > 1 // Don't get back the bytes unless we replicate them + val res = diskStore.putValues(blockId, values, level, askForBytes) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => } } - - // Store the storage level - setLevel(blockId, level, tellMaster) + + // Now that the block is in either the memory or disk store, let other threads read it, + // and tell the master about it. + myInfo.markReady(size) + if (tellMaster) { + reportBlockStatus(blockId) + } } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required + // Replicate block if required if (level.replication > 1) { - if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) } - replicate(blockId, bytes, level) + replicate(blockId, bytesAfterPut, level) } + BlockManager.dispose(bytesAfterPut) + // TODO: This code will be removed when CacheTracker is gone. if (blockId.startsWith("rdd")) { - notifyTheCacheTracker(blockId) + notifyCacheTracker(blockId) } logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) + + return size } /** * Put a new block of serialized bytes to the block manager. */ - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { + def putBytes( + blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { + if (blockId == null) { throw new IllegalArgumentException("Block Id is null") } @@ -435,14 +656,26 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m if (level == null || !level.isValid) { throw new IllegalArgumentException("Storage level is null or invalid") } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as + + if (blockInfo.containsKey(blockId)) { + logWarning("Block " + blockId + " already exists on this machine; not re-adding it") + return + } + + // Remember the block's storage level so that we can correctly drop it to disk if it needs + // to be dropped right after it got put into memory. Note, however, that other threads will + // not be able to get() this block until we call markReady on its BlockInfo. + val myInfo = new BlockInfo(level, tellMaster) + blockInfo.put(blockId, myInfo) + + val startTimeMs = System.currentTimeMillis + + // Initiate the replication before storing it locally. This is faster as // data is already serialized and ready for sending val replicationFuture = if (level.replication > 1) { + val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper Future { - replicate(blockId, bytes, level) + replicate(blockId, bufferView, level) } } else { null @@ -451,27 +684,29 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m locker.getLock(blockId).synchronized { logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") - if (getLevel(blockId) != null) { - logWarning("Block " + blockId + " already exists") - return - } if (level.useMemory) { + // Store it only in memory at first, even if useDisk is also set to true + bytes.rewind() memoryStore.putBytes(blockId, bytes, level) - } - if (level.useDisk) { + } else { + bytes.rewind() diskStore.putBytes(blockId, bytes, level) } - // Store the storage level - setLevel(blockId, level, tellMaster) + // Now that the block is in either the memory or disk store, let other threads read it, + // and tell the master about it. + myInfo.markReady(bytes.limit) + if (tellMaster) { + reportBlockStatus(blockId) + } } // TODO: This code will be removed when CacheTracker is gone. if (blockId.startsWith("rdd")) { - notifyTheCacheTracker(blockId) + notifyCacheTracker(blockId) } - + // If replication had started, then wait for it to finish if (level.replication > 1) { if (replicationFuture == null) { @@ -480,12 +715,11 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m Await.ready(replicationFuture, Duration.Inf) } - val finishTime = System.currentTimeMillis if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + + logDebug("PutBytes for block " + blockId + " with replication took " + Utils.getUsedTimeMs(startTimeMs)) } else { - logDebug("PutBytes for block " + blockId + " without replication took " + + logDebug("PutBytes for block " + blockId + " without replication took " + Utils.getUsedTimeMs(startTimeMs)) } } @@ -493,39 +727,43 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m /** * Replicate block to another node. */ - + var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) - for (peer: BlockManagerId <- peers) { + if (cachedPeers == null) { + cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) + } + for (peer: BlockManagerId <- cachedPeers) { val start = System.nanoTime + data.rewind() logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.array().length + " Bytes. To node: " + peer) + + data.limit() + " Bytes. To node: " + peer) if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), new ConnectionManagerId(peer.ip, peer.port))) { logError("Failed to call syncPutBlock to " + peer) } logDebug("Replicated BlockId " + blockId + " once used " + (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.array().length + " bytes.") + data.limit() + " bytes.") } } // TODO: This code will be removed when CacheTracker is gone. - private def notifyTheCacheTracker(key: String) { - val rddInfo = key.split(":") - val rddId: Int = rddInfo(1).toInt - val splitIndex: Int = rddInfo(2).toInt - val host = System.getProperty("spark.hostname", Utils.localHostName) - cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host)) + private def notifyCacheTracker(key: String) { + if (cacheTracker != null) { + val rddInfo = key.split("_") + val rddId: Int = rddInfo(1).toInt + val partition: Int = rddInfo(2).toInt + cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host)) + } } /** * Read a block consisting of a single object. */ def getSingle(blockId: String): Option[Any] = { - get(blockId).map(_.next) + get(blockId).map(_.next()) } /** @@ -536,42 +774,76 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } /** - * Drop block from memory (called when memory store has reached it limit) + * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory + * store reaches its limit and needs to free up space. */ - def dropFromMemory(blockId: String) { + def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { + logInfo("Dropping block " + blockId + " from memory") locker.getLock(blockId).synchronized { - val level = getLevel(blockId) - if (level == null) { - logWarning("Block " + blockId + " cannot be removed from memory as it does not exist") - return + val info = blockInfo.get(blockId) + val level = info.level + if (level.useDisk && !diskStore.contains(blockId)) { + logInfo("Writing block " + blockId + " to disk") + data match { + case Left(elements) => + diskStore.putValues(blockId, elements, level, false) + case Right(bytes) => + diskStore.putBytes(blockId, bytes, level) + } + } + memoryStore.remove(blockId) + if (info.tellMaster) { + reportBlockStatus(blockId) } - if (!level.useMemory) { - logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory") - return + if (!level.useDisk) { + // The block is completely gone from this node; forget it so we can put() it again later. + blockInfo.remove(blockId) } - memoryStore.remove(blockId) - val newLevel = new StorageLevel(level.useDisk, false, level.deserialized, level.replication) - setLevel(blockId, newLevel) } } - def dataSerialize(values: Iterator[Any]): ByteBuffer = { - /*serializer.newInstance().serializeMany(values)*/ + def shouldCompress(blockId: String): Boolean = { + if (blockId.startsWith("shuffle_")) { + compressShuffle + } else if (blockId.startsWith("broadcast_")) { + compressBroadcast + } else if (blockId.startsWith("rdd_")) { + compressRdds + } else { + false // Won't happen in a real cluster, but it can in tests + } + } + + /** + * Wrap an output stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { + if (shouldCompress(blockId)) new LZFOutputStream(s) else s + } + + /** + * Wrap an input stream for compression if block compression is enabled for its block type + */ + def wrapForCompression(blockId: String, s: InputStream): InputStream = { + if (shouldCompress(blockId)) new LZFInputStream(s) else s + } + + def dataSerialize(blockId: String, values: Iterator[Any]): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) - serializer.newInstance().serializeStream(byteStream).writeAll(values).close() + val ser = serializer.newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() byteStream.trim() ByteBuffer.wrap(byteStream.array) } - def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = { - /*serializer.newInstance().deserializeMany(bytes)*/ - val ser = serializer.newInstance() + /** + * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of + * the iterator is reached. + */ + def dataDeserialize(blockId: String, bytes: ByteBuffer): Iterator[Any] = { bytes.rewind() - return ser.deserializeStream(new ByteBufferInputStream(bytes)).toIterator - } - - private def notifyMaster(heartBeat: HeartBeat) { - master.mustHeartBeat(heartBeat) + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + serializer.newInstance().deserializeStream(stream).asIterator } def stop() { @@ -583,9 +855,25 @@ class BlockManager(val master: BlockManagerMaster, val serializer: Serializer, m } } -object BlockManager { - def getMaxMemoryFromSystemProperties(): Long = { +private[spark] +object BlockManager extends Logging { + def getMaxMemoryFromSystemProperties: Long = { val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong } + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + def dispose(buffer: ByteBuffer) { + if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { + logDebug("Unmapping " + buffer) + if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { + buffer.asInstanceOf[DirectBuffer].cleaner().clean() + } + } + } } diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index 9f03c5a32c..7bfa31ac3d 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -3,37 +3,35 @@ package spark.storage import java.io._ import java.util.{HashMap => JHashMap} -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import akka.actor._ import akka.dispatch._ import akka.pattern.ask import akka.remote._ -import akka.util.Duration -import akka.util.Timeout +import akka.util.{Duration, Timeout} import akka.util.duration._ -import spark.Logging -import spark.SparkException -import spark.Utils +import spark.{Logging, SparkException, Utils} + +private[spark] sealed trait ToBlockManagerMaster +private[spark] case class RegisterBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, - maxDiskSize: Long) + maxMemSize: Long) extends ToBlockManagerMaster - + +private[spark] class HeartBeat( var blockManagerId: BlockManagerId, var blockId: String, var storageLevel: StorageLevel, - var deserializedSize: Long, - var size: Long) + var memSize: Long, + var diskSize: Long) extends ToBlockManagerMaster with Externalizable { @@ -43,8 +41,8 @@ class HeartBeat( blockManagerId.writeExternal(out) out.writeUTF(blockId) storageLevel.writeExternal(out) - out.writeInt(deserializedSize.toInt) - out.writeInt(size.toInt) + out.writeInt(memSize.toInt) + out.writeInt(diskSize.toInt) } override def readExternal(in: ObjectInput) { @@ -53,84 +51,101 @@ class HeartBeat( blockId = in.readUTF() storageLevel = new StorageLevel() storageLevel.readExternal(in) - deserializedSize = in.readInt() - size = in.readInt() + memSize = in.readInt() + diskSize = in.readInt() } } +private[spark] object HeartBeat { def apply(blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, - deserializedSize: Long, - size: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + memSize: Long, + diskSize: Long): HeartBeat = { + new HeartBeat(blockManagerId, blockId, storageLevel, memSize, diskSize) } - // For pattern-matching def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size)) + Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } +private[spark] case class GetLocations(blockId: String) extends ToBlockManagerMaster +private[spark] case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster +private[spark] case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster +private[spark] case class RemoveHost(host: String) extends ToBlockManagerMaster +private[spark] case object StopBlockManagerMaster extends ToBlockManagerMaster -class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { +private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { class BlockManagerInfo( + val blockManagerId: BlockManagerId, timeMs: Long, - maxMem: Long, - maxDisk: Long) { + val maxMem: Long) { private var lastSeenMs = timeMs - private var remainedMem = maxMem - private var remainedDisk = maxDisk + private var remainingMem = maxMem private val blocks = new JHashMap[String, StorageLevel] + + logInfo("Registering block manager %s:%d with %s RAM".format( + blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem))) def updateLastSeenMs() { lastSeenMs = System.currentTimeMillis() / 1000 } - def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) = - synchronized { + def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long) + : Unit = synchronized { + updateLastSeenMs() if (blocks.containsKey(blockId)) { - val oriLevel: StorageLevel = blocks.get(blockId) + // The block exists on the slave already. + val originalLevel: StorageLevel = blocks.get(blockId) - if (oriLevel.deserialized) { - remainedMem += deserializedSize - } - if (oriLevel.useMemory) { - remainedMem += size - } - if (oriLevel.useDisk) { - remainedDisk += size + if (originalLevel.useMemory) { + remainingMem += memSize } } - if (storageLevel.isValid) { + if (storageLevel.isValid) { + // isValid means it is either stored in-memory or on-disk. blocks.put(blockId, storageLevel) - if (storageLevel.deserialized) { - remainedMem -= deserializedSize - } if (storageLevel.useMemory) { - remainedMem -= size + remainingMem -= memSize + logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(remainingMem))) } if (storageLevel.useDisk) { - remainedDisk -= size + logInfo("Added %s on disk on %s:%d (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) } - } else { + } else if (blocks.containsKey(blockId)) { + // If isValid is not true, drop the block. + val originalLevel: StorageLevel = blocks.get(blockId) blocks.remove(blockId) + if (originalLevel.useMemory) { + remainingMem += memSize + logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize), + Utils.memoryBytesToString(remainingMem))) + } + if (originalLevel.useDisk) { + logInfo("Removed %s on %s:%d on disk (size: %s)".format( + blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize))) + } } } @@ -139,15 +154,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } def getRemainedMem: Long = { - return remainedMem - } - - def getRemainedDisk: Long = { - return remainedDisk + return remainingMem } override def toString: String = { - return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk + return "BlockManagerInfo " + timeMs + " " + remainingMem } def clear() { @@ -171,8 +182,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) => - register(blockManagerId, maxMemSize, maxDiskSize) + case RegisterBlockManager(blockManagerId, maxMemSize) => + register(blockManagerId, maxMemSize) case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) @@ -200,16 +211,15 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { logInfo("Got unknown message: " + other) } - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) { + private def register(blockManagerId: BlockManagerId, maxMemSize: Long) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - logInfo("Got Register Msg from " + blockManagerId) if (blockManagerId.ip == Utils.localHostName() && !isLocal) { logInfo("Got Register Msg from master node, don't register it") } else { blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize)) + blockManagerId, System.currentTimeMillis() / 1000, maxMemSize)) } logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) sender ! true @@ -219,8 +229,8 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { blockManagerId: BlockManagerId, blockId: String, storageLevel: StorageLevel, - deserializedSize: Long, - size: Long) { + memSize: Long, + diskSize: Long) { val startTimeMs = System.currentTimeMillis() val tmp = " " + blockManagerId + " " + blockId + " " @@ -231,7 +241,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! true } - blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size) + blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) var locations: HashSet[BlockManagerId] = null if (blockInfo.containsKey(blockId)) { @@ -329,7 +339,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } } -class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) +private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) extends Logging { val AKKA_ACTOR_NAME: String = "BlockMasterManager" @@ -386,10 +396,12 @@ class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: B } def mustRegisterBlockManager(msg: RegisterBlockManager) { + logInfo("Trying to register BlockManager") while (! syncRegisterBlockManager(msg)) { logWarning("Failed to register " + msg) Thread.sleep(REQUEST_RETRY_INTERVAL_MS) } + logInfo("Done registering BlockManager") } def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala index d74cdb38a8..d2985559c1 100644 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -1,25 +1,23 @@ package spark.storage -import java.nio._ +import java.nio.ByteBuffer import scala.actors._ import scala.actors.Actor._ import scala.actors.remote._ - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random -import spark.Logging -import spark.Utils -import spark.SparkEnv +import spark.{Logging, Utils, SparkEnv} import spark.network._ /** - * This should be changed to use event model late. + * A network interface for BlockManager. Each slave should have one + * BlockManagerWorker. + * + * TODO: Use event model. */ -class BlockManagerWorker(val blockManager: BlockManager) extends Logging { +private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { initLogging() blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) @@ -32,11 +30,10 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { logDebug("Handling as a buffer message " + bufferMessage) val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get) - /*logDebug("Processed block messages")*/ + val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) return Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => logError("Exception handling buffer message: " + e.getMessage) + case e: Exception => logError("Exception handling buffer message", e) return None } } @@ -51,13 +48,13 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { blockMessage.getType match { case BlockMessage.TYPE_PUT_BLOCK => { val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logInfo("Received [" + pB + "]") + logDebug("Received [" + pB + "]") putBlock(pB.id, pB.data, pB.level) return None } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) - logInfo("Received [" + gB + "]") + logDebug("Received [" + gB + "]") val buffer = getBlock(gB.id) if (buffer == null) { return None @@ -73,22 +70,15 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.array().length) + + " with data size: " + bytes.limit) } private def getBlock(id: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() - logDebug("Getblock " + id + " started from " + startTimeMs) - val block = blockManager.getLocal(id) - val buffer = block match { - case Some(tValues) => { - val values = tValues - val buffer = blockManager.dataSerialize(values) - buffer - } - case None => { - null - } + logDebug("GetBlock " + id + " started from " + startTimeMs) + val buffer = blockManager.getLocalBytes(id) match { + case Some(bytes) => bytes + case None => null } logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) @@ -96,7 +86,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging { } } -object BlockManagerWorker extends Logging { +private[spark] object BlockManagerWorker extends Logging { private var blockManagerWorker: BlockManagerWorker = null private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala index 0b2ed69e07..3f234df654 100644 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -1,6 +1,6 @@ package spark.storage -import java.nio._ +import java.nio.ByteBuffer import scala.collection.mutable.StringBuilder import scala.collection.mutable.ArrayBuffer @@ -8,11 +8,11 @@ import scala.collection.mutable.ArrayBuffer import spark._ import spark.network._ -case class GetBlock(id: String) -case class GotBlock(id: String, data: ByteBuffer) -case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) +private[spark] case class GetBlock(id: String) +private[spark] case class GotBlock(id: String, data: ByteBuffer) +private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) -class BlockMessage() extends Logging{ +private[spark] class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -22,8 +22,6 @@ class BlockMessage() extends Logging{ private var data: ByteBuffer = null private var level: StorageLevel = null - initLogging() - def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -62,8 +60,6 @@ class BlockMessage() extends Logging{ } id = idBuilder.toString() - logDebug("Set from buffer Result: " + typ + " " + id) - logDebug("Buffer position is " + buffer.position) if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() @@ -77,23 +73,18 @@ class BlockMessage() extends Logging{ } data.put(buffer) data.flip() - logDebug("Set from buffer Result 2: " + level + " " + data) } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { val dataLength = buffer.getInt() - logDebug("Data length is "+ dataLength) - logDebug("Buffer position is " + buffer.position) data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { throw new Exception("Error parsing buffer") } data.put(buffer) data.flip() - logDebug("Set from buffer Result 3: " + data) } val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s") } def set(bufferMsg: BufferMessage) { @@ -145,8 +136,6 @@ class BlockMessage() extends Logging{ buffers += data } - logDebug("Start to log buffers.") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* println() println("BlockMessage: ") @@ -160,7 +149,6 @@ class BlockMessage() extends Logging{ println() */ val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") return Message.createBufferMessage(buffers) } @@ -170,7 +158,7 @@ class BlockMessage() extends Logging{ } } -object BlockMessage { +private[spark] object BlockMessage { val TYPE_NON_INITIALIZED: Int = 0 val TYPE_GET_BLOCK: Int = 1 val TYPE_GOT_BLOCK: Int = 2 @@ -208,7 +196,7 @@ object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2)) + B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala index 497a19856e..a25decb123 100644 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -1,5 +1,6 @@ package spark.storage -import java.nio._ + +import java.nio.ByteBuffer import scala.collection.mutable.StringBuilder import scala.collection.mutable.ArrayBuffer @@ -7,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer import spark._ import spark.network._ +private[spark] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { def this(bm: BlockMessage) = this(Array(bm)) @@ -84,7 +86,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockM } } -object BlockMessageArray { +private[spark] object BlockMessageArray { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() @@ -98,7 +100,7 @@ object BlockMessageArray { if (i % 2 == 0) { val buffer = ByteBuffer.allocate(100) buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY)) + BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) } else { BlockMessage.fromGetBlock(GetBlock(i.toString)) } diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala index 17f4f51aa8..096bf8bdd9 100644 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -1,25 +1,31 @@ package spark.storage -import spark.{Utils, Logging, Serializer, SizeEstimator} -import scala.collection.mutable.ArrayBuffer -import java.io.{File, RandomAccessFile} import java.nio.ByteBuffer -import java.nio.channels.FileChannel.MapMode -import java.util.{UUID, LinkedHashMap} -import java.util.concurrent.Executors -import java.util.concurrent.ConcurrentHashMap -import it.unimi.dsi.fastutil.io._ -import java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer + +import spark.Logging /** * Abstract class to store blocks */ -abstract class BlockStore(blockManager: BlockManager) extends Logging { - initLogging() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] +private[spark] +abstract class BlockStore(val blockManager: BlockManager) extends Logging { + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + + /** + * Put in a block and, possibly, also return its content as either bytes or another Iterator. + * This is used to efficiently write the values to multiple locations (e.g. for replication). + * + * @return a PutResult that contains the size of the data, as well as the values put if + * returnValues is true (if not, the result's data field can be null) + */ + def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + returnValues: Boolean) : PutResult + + /** + * Return the size of a block in bytes. + */ + def getSize(blockId: String): Long def getBytes(blockId: String): Option[ByteBuffer] @@ -27,284 +33,7 @@ abstract class BlockStore(blockManager: BlockManager) extends Logging { def remove(blockId: String) - def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values) - - def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes) + def contains(blockId: String): Boolean def clear() { } } - -/** - * Class to store blocks in memory - */ -class MemoryStore(blockManager: BlockManager, maxMemory: Long) - extends BlockStore(blockManager) { - - case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) - - private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true) - private var currentMemory = 0L - - //private val blockDropper = Executors.newSingleThreadExecutor() - private val blocksToDrop = new ArrayBlockingQueue[String](10000, true) - private val blockDropper = new Thread("memory store - block dropper") { - override def run() { - try{ - while (true) { - val blockId = blocksToDrop.take() - logDebug("Block " + blockId + " ready to be dropped") - blockManager.dropFromMemory(blockId) - } - } catch { - case ie: InterruptedException => - logInfo("Shutting down block dropper") - } - } - } - blockDropper.start() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { - if (level.deserialized) { - bytes.rewind() - val values = dataDeserialize(bytes) - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - ensureFreeSpace(sizeEstimate) - val entry = new Entry(elements, sizeEstimate, true) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += sizeEstimate - logDebug("Block " + blockId + " stored as values to memory") - } else { - val entry = new Entry(bytes, bytes.array().length, false) - ensureFreeSpace(bytes.array.length) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") - } - } - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { - if (level.deserialized) { - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - ensureFreeSpace(sizeEstimate) - val entry = new Entry(elements, sizeEstimate, true) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += sizeEstimate - logDebug("Block " + blockId + " stored as values to memory") - return Left(elements.iterator) - } else { - val bytes = dataSerialize(values) - ensureFreeSpace(bytes.array().length) - val entry = new Entry(bytes, bytes.array().length, false) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory") - return Right(bytes) - } - } - - def getBytes(blockId: String): Option[ByteBuffer] = { - throw new UnsupportedOperationException("Not implemented") - } - - def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = memoryStore.synchronized { memoryStore.get(blockId) } - if (entry == null) { - return None - } - if (entry.deserialized) { - return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) - } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) - } - } - - def remove(blockId: String) { - memoryStore.synchronized { - val entry = memoryStore.get(blockId) - if (entry != null) { - memoryStore.remove(blockId) - currentMemory -= entry.size - logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory") - } else { - logWarning("Block " + blockId + " could not be removed as it doesnt exist") - } - } - } - - override def clear() { - memoryStore.synchronized { - memoryStore.clear() - } - //blockDropper.shutdown() - blockDropper.interrupt() - logInfo("MemoryStore cleared") - } - - private def ensureFreeSpace(space: Long) { - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) - - if (maxMemory - currentMemory < space) { - - val selectedBlocks = new ArrayBuffer[String]() - var selectedMemory = 0L - - memoryStore.synchronized { - val iter = memoryStore.entrySet().iterator() - while (maxMemory - (currentMemory - selectedMemory) < space && iter.hasNext) { - val pair = iter.next() - val blockId = pair.getKey - val entry = pair.getValue() - if (!entry.dropPending) { - selectedBlocks += blockId - entry.dropPending = true - } - selectedMemory += pair.getValue.size - logDebug("Block " + blockId + " selected for dropping") - } - } - - logDebug("" + selectedBlocks.size + " new blocks selected for dropping, " + - blocksToDrop.size + " blocks pending") - var i = 0 - while (i < selectedBlocks.size) { - blocksToDrop.add(selectedBlocks(i)) - i += 1 - } - selectedBlocks.clear() - } - } -} - - -/** - * Class to store blocks in disk - */ -class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { - - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val localDirs = createLocalDirs() - var lastLocalDirUsed = 0 - - addShutdownHook() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { - logDebug("Attempting to put block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - if (file != null) { - val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) - buffer.put(bytes.array) - channel.close() - val finishTime = System.currentTimeMillis - logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") - } else { - logError("File not created for block " + blockId) - } - } - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { - val bytes = dataSerialize(values) - logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") - putBytes(blockId, bytes, level) - return Right(bytes) - } - - def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = ByteBuffer.allocate(length) - bytes.put(channel.map(MapMode.READ_WRITE, 0, length)) - return Some(bytes) - } - - def getValues(blockId: String): Option[Iterator[Any]] = { - val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - val buffer = dataDeserialize(bytes) - channel.close() - return Some(buffer) - } - - def remove(blockId: String) { - throw new UnsupportedOperationException("Not implemented") - } - - private def createFile(blockId: String): File = { - val file = getFile(blockId) - if (file == null) { - lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size - val newFile = new File(localDirs(lastLocalDirUsed), blockId) - newFile.getParentFile.mkdirs() - return newFile - } else { - logError("File for block " + blockId + " already exists on disk, " + file) - return null - } - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - // Search for the file in all the local directories, only one of them should have the file - val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists) - if (files.size > 1) { - throw new Exception("Multiple files for same block " + blockId + " exists: " + - files.map(_.toString).reduceLeft(_ + ", " + _)) - return null - } else if (files.size == 0) { - return null - } else { - logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes") - return files(0) - } - } - - private def createLocalDirs(): Seq[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - rootDirs.split("[;,:]").map(rootDir => { - var foundLocalDir: Boolean = false - var localDir: File = null - var localDirUuid: UUID = null - var tries = 0 - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirUuid = UUID.randomUUID() - localDir = new File(rootDir, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(1) - } - logDebug("Created local directory at " + localDir) - localDir - }) - } - - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) - } - }) - } -} diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala new file mode 100644 index 0000000000..8ba64e4b76 --- /dev/null +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -0,0 +1,180 @@ +package spark.storage + +import java.nio.ByteBuffer +import java.io.{File, FileOutputStream, RandomAccessFile} +import java.nio.channels.FileChannel.MapMode +import java.util.{Random, Date} +import java.text.SimpleDateFormat + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import scala.collection.mutable.ArrayBuffer + +import spark.Utils + +/** + * Stores BlockManager blocks on disk. + */ +private class DiskStore(blockManager: BlockManager, rootDirs: String) + extends BlockStore(blockManager) { + + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + + // Create one local directory for each path mentioned in spark.local.dir; then, inside this + // directory, create multiple subdirectories that we will hash files into, in order to avoid + // having really large inodes at the top level. + val localDirs = createLocalDirs() + val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + + addShutdownHook() + + override def getSize(blockId: String): Long = { + getFile(blockId).length() + } + + override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + logDebug("Attempting to put block " + blockId) + val startTime = System.currentTimeMillis + val file = createFile(blockId) + val channel = new RandomAccessFile(file, "rw").getChannel() + while (bytes.remaining > 0) { + channel.write(bytes) + } + channel.close() + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file on disk in %d ms".format( + blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime))) + } + + override def putValues( + blockId: String, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + logDebug("Attempting to write values for block " + blockId) + val startTime = System.currentTimeMillis + val file = createFile(blockId) + val fileOut = blockManager.wrapForCompression(blockId, + new FastBufferedOutputStream(new FileOutputStream(file))) + val objOut = blockManager.serializer.newInstance().serializeStream(fileOut) + objOut.writeAll(values.iterator) + objOut.close() + val length = file.length() + logDebug("Block %s stored as %s file on disk in %d ms".format( + blockId, Utils.memoryBytesToString(length), (System.currentTimeMillis - startTime))) + + if (returnValues) { + // Return a byte buffer for the contents of the file + val channel = new RandomAccessFile(file, "r").getChannel() + val buffer = channel.map(MapMode.READ_ONLY, 0, length) + channel.close() + PutResult(length, Right(buffer)) + } else { + PutResult(length, null) + } + } + + override def getBytes(blockId: String): Option[ByteBuffer] = { + val file = getFile(blockId) + val length = file.length().toInt + val channel = new RandomAccessFile(file, "r").getChannel() + val bytes = channel.map(MapMode.READ_ONLY, 0, length) + channel.close() + Some(bytes) + } + + override def getValues(blockId: String): Option[Iterator[Any]] = { + getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) + } + + override def remove(blockId: String) { + val file = getFile(blockId) + if (file.exists()) { + file.delete() + } + } + + override def contains(blockId: String): Boolean = { + getFile(blockId).exists() + } + + private def createFile(blockId: String): File = { + val file = getFile(blockId) + if (file.exists()) { + throw new Exception("File for block " + blockId + " already exists on disk: " + file) + } + file + } + + private def getFile(blockId: String): File = { + logDebug("Getting file for block " + blockId) + + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = math.abs(blockId.hashCode) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + newDir.mkdir() + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + + new File(subDir, blockId) + } + + private def createLocalDirs(): Array[File] = { + logDebug("Creating local directories at root dirs '" + rootDirs + "'") + val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + rootDirs.split(",").map(rootDir => { + var foundLocalDir: Boolean = false + var localDir: File = null + var localDirId: String = null + var tries = 0 + val rand = new Random() + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) + localDir = new File(rootDir, "spark-local-" + localDirId) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + + " attempts to create local dir in " + rootDir) + System.exit(1) + } + logInfo("Created local directory at " + localDir) + localDir + }) + } + + private def addShutdownHook() { + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + override def run() { + logDebug("Shutdown hook called") + localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + } + }) + } +} diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala new file mode 100644 index 0000000000..773970446a --- /dev/null +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -0,0 +1,218 @@ +package spark.storage + +import java.util.LinkedHashMap +import java.util.concurrent.ArrayBlockingQueue +import spark.{SizeEstimator, Utils} +import java.nio.ByteBuffer +import collection.mutable.ArrayBuffer + +/** + * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as + * serialized ByteBuffers. + */ +private class MemoryStore(blockManager: BlockManager, maxMemory: Long) + extends BlockStore(blockManager) { + + case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false) + + private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) + private var currentMemory = 0L + + logInfo("MemoryStore started with capacity %s.".format(Utils.memoryBytesToString(maxMemory))) + + def freeMemory: Long = maxMemory - currentMemory + + override def getSize(blockId: String): Long = { + synchronized { + entries.get(blockId).size + } + } + + override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + if (level.deserialized) { + bytes.rewind() + val values = blockManager.dataDeserialize(blockId, bytes) + val elements = new ArrayBuffer[Any] + elements ++= values + val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) + tryToPut(blockId, elements, sizeEstimate, true) + } else { + val entry = new Entry(bytes, bytes.limit, false) + ensureFreeSpace(blockId, bytes.limit) + synchronized { entries.put(blockId, entry) } + tryToPut(blockId, bytes, bytes.limit, false) + } + } + + override def putValues( + blockId: String, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean) + : PutResult = { + + if (level.deserialized) { + val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef]) + tryToPut(blockId, values, sizeEstimate, true) + PutResult(sizeEstimate, Left(values.iterator)) + } else { + val bytes = blockManager.dataSerialize(blockId, values.iterator) + tryToPut(blockId, bytes, bytes.limit, false) + PutResult(bytes.limit(), Right(bytes)) + } + } + + override def getBytes(blockId: String): Option[ByteBuffer] = { + val entry = synchronized { + entries.get(blockId) + } + if (entry == null) { + None + } else if (entry.deserialized) { + Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)) + } else { + Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + } + } + + override def getValues(blockId: String): Option[Iterator[Any]] = { + val entry = synchronized { + entries.get(blockId) + } + if (entry == null) { + None + } else if (entry.deserialized) { + Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator) + } else { + val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data + Some(blockManager.dataDeserialize(blockId, buffer)) + } + } + + override def remove(blockId: String) { + synchronized { + val entry = entries.get(blockId) + if (entry != null) { + entries.remove(blockId) + currentMemory -= entry.size + logInfo("Block %s of size %d dropped from memory (free %d)".format( + blockId, entry.size, freeMemory)) + } else { + logWarning("Block " + blockId + " could not be removed as it does not exist") + } + } + } + + override def clear() { + synchronized { + entries.clear() + } + logInfo("MemoryStore cleared") + } + + /** + * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. + */ + private def getRddId(blockId: String): String = { + if (blockId.startsWith("rdd_")) { + blockId.split('_')(1) + } else { + null + } + } + + /** + * Try to put in a set of values, if we can free up enough space. The value should either be + * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated) + * size must also be passed by the caller. + */ + private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { + synchronized { + if (ensureFreeSpace(blockId, size)) { + val entry = new Entry(value, size, deserialized) + entries.put(blockId, entry) + currentMemory += size + if (deserialized) { + logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format( + blockId, Utils.memoryBytesToString(size), Utils.memoryBytesToString(freeMemory))) + } else { + logInfo("Block %s stored as bytes to memory (size %s, free %s)".format( + blockId, Utils.memoryBytesToString(size), Utils.memoryBytesToString(freeMemory))) + } + true + } else { + // Tell the block manager that we couldn't put it in memory so that it can drop it to + // disk if the block allows disk storage. + val data = if (deserialized) { + Left(value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + false + } + } + } + + /** + * Tries to free up a given amount of space to store a particular block, but can fail and return + * false if either the block is bigger than our memory or it would require replacing another + * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that + * don't fit into memory that we want to avoid). + * + * Assumes that a lock on the MemoryStore is held by the caller. (Otherwise, the freed space + * might fill up before the caller puts in their new value.) + */ + private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( + space, currentMemory, maxMemory)) + + if (space > maxMemory) { + logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit") + return false + } + + // TODO: This should relinquish the lock on the MemoryStore while flushing out old blocks + // in order to allow parallelism in writing to disk + if (maxMemory - currentMemory < space) { + val rddToAdd = getRddId(blockIdToAdd) + val selectedBlocks = new ArrayBuffer[String]() + var selectedMemory = 0L + + val iterator = entries.entrySet().iterator() + while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { + val pair = iterator.next() + val blockId = pair.getKey + if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + + "block from the same RDD") + return false + } + selectedBlocks += blockId + selectedMemory += pair.getValue.size + } + + if (maxMemory - (currentMemory - selectedMemory) >= space) { + logInfo(selectedBlocks.size + " blocks selected for dropping") + for (blockId <- selectedBlocks) { + val entry = entries.get(blockId) + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + } + return true + } else { + return false + } + } + return true + } + + override def contains(blockId: String): Boolean = { + synchronized { entries.containsKey(blockId) } + } +} + diff --git a/core/src/main/scala/spark/storage/PutResult.scala b/core/src/main/scala/spark/storage/PutResult.scala new file mode 100644 index 0000000000..76f236057b --- /dev/null +++ b/core/src/main/scala/spark/storage/PutResult.scala @@ -0,0 +1,9 @@ +package spark.storage + +import java.nio.ByteBuffer + +/** + * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the + * values put if the caller asked for them to be returned (e.g. for chaining replication) + */ +private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer]) diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala index f067a2a6c5..c497f03e0c 100644 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -1,7 +1,14 @@ package spark.storage -import java.io._ +import java.io.{Externalizable, ObjectInput, ObjectOutput} +/** + * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, + * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory + * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. + * The [[spark.storage.StorageLevel$]] singleton object contains some static constants for + * commonly useful storage levels. + */ class StorageLevel( var useDisk: Boolean, var useMemory: Boolean, @@ -66,12 +73,13 @@ class StorageLevel( object StorageLevel { val NONE = new StorageLevel(false, false, false) val DISK_ONLY = new StorageLevel(true, false, false) - val MEMORY_ONLY = new StorageLevel(false, true, false) - val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) - val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) - val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2) - val DISK_AND_MEMORY = new StorageLevel(true, true, false) - val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2) - val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true) - val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2) + val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) + val MEMORY_ONLY = new StorageLevel(false, true, true) + val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2) + val MEMORY_ONLY_SER = new StorageLevel(false, true, false) + val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2) + val MEMORY_AND_DISK = new StorageLevel(true, true, true) + val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) + val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) + val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 57d212e4ca..b466b5239c 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -17,12 +17,14 @@ import java.util.concurrent.TimeoutException /** * Various utility classes for working with Akka. */ -object AkkaUtils { +private[spark] object AkkaUtils { /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the * ActorSystem itself and its port (which is hard to get from Akka). */ def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt val akkaConf = ConfigFactory.parseString(""" akka.daemonic = on akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] @@ -31,7 +33,9 @@ object AkkaUtils { akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d akka.remote.netty.connection-timeout = 1s - """.format(host, port)) + akka.remote.netty.execution-pool-size = %d + akka.actor.default-dispatcher.throughput = %d + """.format(host, port, akkaThreads, akkaBatchSize)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala index 0ce255105a..d7e67497fe 100644 --- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala @@ -2,10 +2,19 @@ package spark.util import java.io.InputStream import java.nio.ByteBuffer +import spark.storage.BlockManager + +/** + * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose() + * at the end of the stream (e.g. to close a memory-mapped file). + */ +private[spark] +class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) + extends InputStream { -class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { override def read(): Int = { - if (buffer.remaining() == 0) { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() -1 } else { buffer.get() & 0xFF @@ -17,7 +26,8 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { } override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - if (buffer.remaining() == 0) { + if (buffer == null || buffer.remaining() == 0) { + cleanUp() -1 } else { val amountToGet = math.min(buffer.remaining(), length) @@ -27,8 +37,27 @@ class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { } override def skip(bytes: Long): Long = { - val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) - return amountToSkip + if (buffer != null) { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + if (buffer.remaining() == 0) { + cleanUp() + } + amountToSkip + } else { + 0L + } + } + + /** + * Clean up the buffer, and potentially dispose of it using BlockManager.dispose(). + */ + private def cleanUp() { + if (buffer != null) { + if (dispose) { + BlockManager.dispose(buffer) + } + buffer = null + } } } diff --git a/core/src/main/scala/spark/util/IntParam.scala b/core/src/main/scala/spark/util/IntParam.scala index c3ff063569..0427646747 100644 --- a/core/src/main/scala/spark/util/IntParam.scala +++ b/core/src/main/scala/spark/util/IntParam.scala @@ -3,7 +3,7 @@ package spark.util /** * An extractor object for parsing strings into integers. */ -object IntParam { +private[spark] object IntParam { def unapply(str: String): Option[Int] = { try { Some(str.toInt) diff --git a/core/src/main/scala/spark/util/MemoryParam.scala b/core/src/main/scala/spark/util/MemoryParam.scala index 4fba914afe..3726738842 100644 --- a/core/src/main/scala/spark/util/MemoryParam.scala +++ b/core/src/main/scala/spark/util/MemoryParam.scala @@ -6,7 +6,7 @@ import spark.Utils * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing * the number of megabytes. Supports the same formats as Utils.memoryStringToMb. */ -object MemoryParam { +private[spark] object MemoryParam { def unapply(str: String): Option[Int] = { try { Some(Utils.memoryStringToMb(str)) diff --git a/core/src/main/scala/spark/util/SerializableBuffer.scala b/core/src/main/scala/spark/util/SerializableBuffer.scala index 0830843a77..09d588fe1c 100644 --- a/core/src/main/scala/spark/util/SerializableBuffer.scala +++ b/core/src/main/scala/spark/util/SerializableBuffer.scala @@ -8,6 +8,7 @@ import java.nio.channels.Channels * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make * it easier to pass ByteBuffers in case class messages. */ +private[spark] class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable { def value = buffer diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala index 11d7939204..5f80180339 100644 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -2,8 +2,10 @@ package spark.util /** * A class for tracking the statistics of a set of numbers (count, mean and variance) in a - * numerically robust way. Includes support for merging two StatCounters. Based on Welford and - * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. + * numerically robust way. Includes support for merging two StatCounters. Based on + * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]]. + * + * @constructor Initialize the StatCounter with the given values. */ class StatCounter(values: TraversableOnce[Double]) extends Serializable { private var n: Long = 0 // Running count of our values @@ -12,8 +14,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { merge(values) + /** Initialize the StatCounter with no values. */ def this() = this(Nil) + /** Add a value into this StatCounter, updating the internal statistics. */ def merge(value: Double): StatCounter = { val delta = value - mu n += 1 @@ -22,11 +26,13 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { this } + /** Add multiple values into this StatCounter, updating the internal statistics. */ def merge(values: TraversableOnce[Double]): StatCounter = { values.foreach(v => merge(v)) this } + /** Merge another StatCounter into this one, adding up the internal statistics. */ def merge(other: StatCounter): StatCounter = { if (other == this) { merge(other.copy()) // Avoid overwriting fields in a weird order @@ -45,6 +51,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { } } + /** Clone this StatCounter */ def copy(): StatCounter = { val other = new StatCounter other.n = n @@ -59,6 +66,7 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { def sum: Double = n * mu + /** Return the variance of the values. */ def variance: Double = { if (n == 0) Double.NaN @@ -66,6 +74,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { m2 / n } + /** + * Return the sample variance, which corrects for bias in estimating the variance by dividing + * by N-1 instead of N. + */ def sampleVariance: Double = { if (n <= 1) Double.NaN @@ -73,8 +85,13 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { m2 / (n - 1) } + /** Return the standard deviation of the values. */ def stdev: Double = math.sqrt(variance) + /** + * Return the sample standard deviation of the values, which corrects for bias in estimating the + * variance by dividing by N-1 instead of N. + */ def sampleStdev: Double = math.sqrt(sampleVariance) override def toString: String = { @@ -83,7 +100,9 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { } object StatCounter { + /** Build a StatCounter from a list of values. */ def apply(values: TraversableOnce[Double]) = new StatCounter(values) + /** Build a StatCounter from a list of values passed as variable-length arguments. */ def apply(values: Double*) = new StatCounter(values) } |