45 files changed, 1893 insertions, 355 deletions
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 17d1978dde..f7f8535594 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -5,5 +5,7 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose +# Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 17d1978dde..f7f8535594 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -5,5 +5,7 @@ log4j.appender.console.target=System.err log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n -# Ignore messages below warning level from Jetty, because it's a bit verbose +# Settings to quiet third party logs that are too verbose log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 1a2ec55876..8b30cd4bfe 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,7 +17,7 @@ package org.apache.spark -import org.apache.spark.util.AppendOnlyMap +import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} /** * A set of functions used to aggregate data. 
@@ -31,30 +31,51 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { + private val sparkConf = SparkEnv.get.conf + private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) + def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kv: Product2[K, V] = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) - } - while (iter.hasNext) { - kv = iter.next() - combiners.changeValue(kv._1, update) + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (iter.hasNext) { + kv = iter.next() + combiners.changeValue(kv._1, update) + } + combiners.iterator + } else { + val combiners = + new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners) + while (iter.hasNext) { + val (k, v) = iter.next() + combiners.insert(k, v) + } + combiners.iterator } - combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new AppendOnlyMap[K, C] - var kc: (K, C) = null - val update = (hadValue: Boolean, oldValue: C) => { - if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + if (!externalSorting) { + val combiners = new AppendOnlyMap[K,C] + var kc: Product2[K, C] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + } + while (iter.hasNext) { + kc = iter.next() + combiners.changeValue(kc._1, update) + } + combiners.iterator + } else { + val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners) + while (iter.hasNext) { + val (k, c) = iter.next() + combiners.insert(k, c) + } + combiners.iterator } - while (iter.hasNext) { - kc = iter.next() - combiners.changeValue(kc._1, update) - } - combiners.iterator } } - diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 66c226e491..139048d5c7 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -677,10 +677,10 @@ class SparkContext( key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (SparkHadoopUtil.get.isYarnMode()) { - // In order for this to work on yarn the user must specify the --addjars option to - // the client to upload the file into the distributed cache to make it show up in the - // current working directory. + if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") { + // In order for this to work in yarn standalone mode the user must specify the + // --addjars option to the client to upload the file into the distributed cache + // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { env.httpFileServer.addJar(new File(fileName)) @@ -1086,7 +1086,7 @@ object SparkContext { * parameters that are passed as the default value of null, instead of throwing an exception * like SparkConf would. 
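As the Aggregator change above shows, both combine paths now branch on `spark.shuffle.externalSorting`, which defaults to true. A minimal sketch (not part of this patch; the job body is illustrative) of how a user could opt back into the old purely in-memory behavior:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object ExternalSortingToggle {
  def main(args: Array[String]) {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("externalSortingToggle")
      .set("spark.shuffle.externalSorting", "false")  // default is "true"
    val sc = new SparkContext(conf)
    // With the flag off, combineValuesByKey uses the plain AppendOnlyMap
    // exactly as before this change, and never spills to disk
    val counts = sc.parallelize(1 to 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
    println(counts.collect().toSeq)
    sc.stop()
  }
}
```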
*/ - private def updatedConf( + private[spark] def updatedConf( conf: SparkConf, master: String, appName: String, diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index e093e2f162..08b592df71 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -54,7 +54,11 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) { + val conf: SparkConf) extends Logging { + + // A mapping of thread ID to amount of memory used for shuffle in bytes + // All accesses should be manually synchronized + val shuffleMemoryMap = mutable.HashMap[Long, Long]() private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e51d274d33..a7b2328a02 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -279,6 +279,11 @@ private[spark] class Executor( //System.exit(1) } } finally { + // TODO: Unregister shuffle memory only for ShuffleMapTask + val shuffleMemoryMap = env.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap.remove(Thread.currentThread().getId) + } runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 4ba4696fef..a73714abca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} -import org.apache.spark.util.AppendOnlyMap - +import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep -private[spark] -class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) +private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } - /** * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. @@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { + // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs). + // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner. + // CoGroupValue is the intermediate state of each value before being merged in compute. 
+ private type CoGroup = ArrayBuffer[Any] + private type CoGroupValue = (Any, Int) // Int is dependency number + private type CoGroupCombiner = Seq[CoGroup] + + private val sparkConf = SparkEnv.get.conf private var serializerClass: String = null def setSerializer(cls: String): CoGroupedRDD[K] = { @@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: override val partitioner = Some(part) - override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = { + val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true) val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size - // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) - } - - val getSeq = (k: K) => { - map.changeValue(k, update) - } - - val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + // A list of (rdd iterator, dependency number) pairs + val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent - rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv => - getSeq(kv._1)(depNum) += kv._2 - } + val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] + rddIterators += ((it, depNum)) } case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { - kv => getSeq(kv._1)(depNum) += kv._2 + val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + rddIterators += ((it, depNum)) + } + } + + if (!externalSorting) { + val map = new AppendOnlyMap[K, CoGroupCombiner] + val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup) + } + val getCombiner: K => CoGroupCombiner = key => { + map.changeValue(key, update) + } + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + getCombiner(kv._1)(depNum) += kv._2 } } + new InterruptibleIterator(context, map.iterator) + } else { + val map = createExternalMap(numRdds) + rddIterators.foreach { case (it, depNum) => + while (it.hasNext) { + val kv = it.next() + map.insert(kv._1, new CoGroupValue(kv._2, depNum)) + } + } + new InterruptibleIterator(context, map.iterator) + } + } + + private def createExternalMap(numRdds: Int) + : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = { + + val createCombiner: (CoGroupValue => CoGroupCombiner) = value => { + val newCombiner = Array.fill(numRdds)(new CoGroup) + value match { case (v, depNum) => newCombiner(depNum) += v } + newCombiner } - new InterruptibleIterator(context, map.iterator) + val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner = + (combiner, value) => { + value match { case (v, depNum) => combiner(depNum) += v } + combiner + } + val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner = + 
(combiner1, combiner2) => { + combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 } + } + new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner]( + createCombiner, mergeValue, mergeCombiners) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c118ddfc01..1248409e35 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -99,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) }, preservesPartitioning = true) } else { // Don't apply map-side combiner. - // A sanity check to make sure mergeCombiners is not defined. - assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) values.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) @@ -267,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) // into a hash table, leading to more objects in the old gen. def createCombiner(v: V) = ArrayBuffer(v) def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v + def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 val bufs = combineByKey[ArrayBuffer[V]]( - createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false) + createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) bufs.asInstanceOf[RDD[(K, Seq[V])]] } @@ -339,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) - : RDD[(K, C)] = { + : RDD[(K, C)] = { combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 043e01dbfb..38b536023b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -106,7 +106,7 @@ class DAGScheduler( // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50.milliseconds + val RESUBMIT_TIMEOUT = 200.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages @@ -196,7 +196,7 @@ class DAGScheduler( */ def receive = { case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) + logTrace("Got event of type " + event.getClass.getName) /** * All events are forwarded to `processEvent()`, so that the event processing logic can diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7156d855d8..301d784b35 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -17,12 +17,14 @@ package org.apache.spark.storage +import java.util.UUID + /** * Identifies a particular Block of data, usually associated with a single file. 
* A Block can be uniquely identified by its filename, but each type of Block has a different * set of keys which produce its unique name. * - * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method. + * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method. */ private[spark] sealed abstract class BlockId { /** A globally unique identifier for this Block. Can be used for ser/de. */ @@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def name = "broadcast_" + broadcastId } -private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { +private[spark] +case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { def name = broadcastId.name + "_" + hType } @@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B def name = "input-" + streamId + "-" + uniqueId } +/** Id associated with temporary data managed as blocks. Not serializable. */ +private[spark] case class TempBlockId(id: UUID) extends BlockId { + def name = "temp_" + id +} + // Intended only for testing purposes private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e63e4eb737..ff9f241fc1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -864,7 +864,7 @@ private[spark] object BlockManager extends Logging { val ID_GENERATOR = new IdGenerator def getMaxMemory(conf: SparkConf): Long = { - val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.66) + val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6) (Runtime.getRuntime.maxMemory * memoryFraction).toLong } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 61e63c60d5..369a277232 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -181,4 +181,8 @@ class DiskBlockObjectWriter( // Only valid if called after close() override def timeWriting() = _timeWriting + + def bytesWritten: Long = { + lastValidPosition - initialPosition + } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index edc1133172..a8ef7fa8b6 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.File import java.text.SimpleDateFormat -import java.util.{Date, Random} +import java.util.{Date, Random, UUID} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode @@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Produces a unique block id and File suitable for intermediate results. 
 */
+  def createTempBlock(): (TempBlockId, File) = {
+    var blockId = new TempBlockId(UUID.randomUUID())
+    while (getFile(blockId).exists()) {
+      blockId = new TempBlockId(UUID.randomUUID())
+    }
+    (blockId, getFile(blockId))
+  }
+
   private def createLocalDirs(): Array[File] = {
     logDebug("Creating local directories at root dirs '" + rootDirs + "'")
     val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 181ae2fd45..8e07a0f29a 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -26,16 +26,23 @@ import org.apache.spark.Logging
 /**
  * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
- * time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in
- * replacement of scala.collection.mutable.HashMap.
+ * timestamp along with each key-value pair. If specified, the timestamp of each pair can be
+ * updated every time it is accessed. Key-value pairs whose timestamps are older than a particular
+ * threshold time can then be removed using the clearOldValues method. This is intended to
+ * be a drop-in replacement for scala.collection.mutable.HashMap.
+ * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
+ *                             updated when it is accessed
  */
-class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
+class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+  extends Map[A, B]() with Logging {
   val internalMap = new ConcurrentHashMap[A, (B, Long)]()
 
   def get(key: A): Option[B] = {
     val value = internalMap.get(key)
-    if (value != null) Some(value._1) else None
+    if (value != null && updateTimeStampOnGet) {
+      internalMap.replace(key, value, (value._1, currentTime))
+    }
+    Option(value).map(_._1)
   }
 
   def iterator: Iterator[(A, B)] = {
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fe710c58ac..62fd6d8da5 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util
 
+import scala.util.Random
+
 class Vector(val elements: Array[Double]) extends Serializable {
 
   def length = elements.length
@@ -124,6 +126,12 @@ object Vector {
 
   def ones(length: Int) = Vector(length, _ => 1)
 
+  /**
+   * Creates a [[org.apache.spark.util.Vector]] of the given length containing random numbers
+   * between 0.0 and 1.0. An optional [[scala.util.Random]] number generator can be provided.
+   */
+  def random(length: Int, random: Random = new XORShiftRandom()) =
+    Vector(length, _ => random.nextDouble())
+
   class Multiplier(num: Double) {
     def * (vec: Vector) = vec * num
   }
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 8bb4ee3bfa..d98c7aa3d7 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -15,7 +15,9 @@
  * limitations under the License.
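Before moving on to AppendOnlyMap: a hedged usage sketch of the `Vector.random` factory added above. The default generator is Spark's `XORShiftRandom`; any `scala.util.Random` can be passed explicitly for reproducible vectors.

```scala
import scala.util.Random
import org.apache.spark.util.Vector

val v1 = Vector.random(5)                  // default, fast XORShiftRandom
val v2 = Vector.random(5, new Random(42))  // seeded, reproducible
println(v1.length + " " + v2.elements.mkString(", "))
```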
*/ -package org.apache.spark.util +package org.apache.spark.util.collection + +import java.util.{Arrays, Comparator} /** * A simple open hash table optimized for the append-only use case, where keys @@ -28,14 +30,15 @@ package org.apache.spark.util * TODO: Cache the hash values of each key? java.util.HashMap does that. */ private[spark] -class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, + V)] with Serializable { require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") require(initialCapacity >= 1, "Invalid initial capacity") private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 - private var growThreshold = LOAD_FACTOR * capacity + private var growThreshold = (LOAD_FACTOR * capacity).toInt // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var haveNullValue = false private var nullValue: V = null.asInstanceOf[V] + // Triggered by destructiveSortedIterator; the underlying data array may no longer be used + private var destroyed = false + private val destructionMessage = "Map state is invalid from destructive sorting!" + private val LOAD_FACTOR = 0.7 /** Get the value for a given key */ def apply(key: K): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { return nullValue @@ -72,6 +80,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi /** Set the value for a key */ def update(key: K, value: V): Unit = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi * for key, if any, or null otherwise. Returns the newly updated value. 
*/ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + assert(!destroyed, destructionMessage) val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Iterator method from Iterable */ - override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { - var pos = -1 - - /** Get the next value we should return from next(), or null if we're finished iterating */ - def nextValue(): (K, V) = { - if (pos == -1) { // Treat position -1 as looking at the null value - if (haveNullValue) { - return (null.asInstanceOf[K], nullValue) + override def iterator: Iterator[(K, V)] = { + assert(!destroyed, destructionMessage) + new Iterator[(K, V)] { + var pos = -1 + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def nextValue(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + return (null.asInstanceOf[K], nullValue) + } + pos += 1 } - pos += 1 - } - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + } + pos += 1 } - pos += 1 + null } - null - } - override def hasNext: Boolean = nextValue() != null + override def hasNext: Boolean = nextValue() != null - override def next(): (K, V) = { - val value = nextValue() - if (value == null) { - throw new NoSuchElementException("End of iterator") + override def next(): (K, V) = { + val value = nextValue() + if (value == null) { + throw new NoSuchElementException("End of iterator") + } + pos += 1 + value } - pos += 1 - value } } @@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Double the table's size and re-hash everything */ - private def growTable() { + protected def growTable() { val newCapacity = capacity * 2 if (newCapacity >= (1 << 30)) { // We can't make the table this big because we want an array of 2x @@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi data = newData capacity = newCapacity mask = newMask - growThreshold = LOAD_FACTOR * newCapacity + growThreshold = (LOAD_FACTOR * newCapacity).toInt } private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) if (highBit == n) n else highBit << 1 } + + /** + * Return an iterator of the map in sorted order. This provides a way to sort the map without + * using additional memory, at the expense of destroying the validity of the map. 
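A small usage sketch of the destructive iterator described above and defined just below. `AppendOnlyMap` is `private[spark]`, so this assumes code living under the `org.apache.spark` package; sorting by key hash mirrors how the spill path uses it.

```scala
import java.util.Comparator
import org.apache.spark.util.collection.AppendOnlyMap

val map = new AppendOnlyMap[String, Int]()
map("a") = 1
map("b") = 2
map.changeValue("a", (hadValue, oldValue) => if (hadValue) oldValue + 10 else 10)

val hashComparator = new Comparator[(String, Int)] {
  def compare(x: (String, Int), y: (String, Int)): Int =
    x._1.hashCode().compareTo(y._1.hashCode())
}
// After this call, apply/update/changeValue/iterator all fail the assertion
map.destructiveSortedIterator(hashComparator).foreach(println)
```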
+ */ + def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = { + destroyed = true + // Pack KV pairs into the front of the underlying array + var keyIndex, newIndex = 0 + while (keyIndex < capacity) { + if (data(2 * keyIndex) != null) { + data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1)) + newIndex += 1 + } + keyIndex += 1 + } + assert(curSize == newIndex + (if (haveNullValue) 1 else 0)) + + // Sort by the given ordering + val rawOrdering = new Comparator[AnyRef] { + def compare(x: AnyRef, y: AnyRef): Int = { + cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)]) + } + } + Arrays.sort(data, 0, newIndex, rawOrdering) + + new Iterator[(K, V)] { + var i = 0 + var nullValueReady = haveNullValue + def hasNext: Boolean = (i < newIndex || nullValueReady) + def next(): (K, V) = { + if (nullValueReady) { + nullValueReady = false + (null.asInstanceOf[K], nullValue) + } else { + val item = data(i).asInstanceOf[(K, V)] + i += 1 + item + } + } + } + } + + /** + * Return whether the next insert will cause the map to grow + */ + def atGrowThreshold: Boolean = curSize == growThreshold } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala new file mode 100644 index 0000000000..e3bcd895aa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import java.io._ +import java.util.Comparator + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter} + +/** + * An append-only map that spills sorted content to disk when there is insufficient space for it + * to grow. + * + * This map takes two passes over the data: + * + * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary + * (2) Combiners are read from disk and merged together + * + * The setting of the spill threshold faces the following trade-off: If the spill threshold is + * too high, the in-memory map may occupy more memory than is available, resulting in OOM. + * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk + * writes. This may lead to a performance regression compared to the normal case of using the + * non-spilling AppendOnlyMap. 
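A minimal usage sketch of the `ExternalAppendOnlyMap` being introduced here, with `ArrayBuffer` combiners as `groupByKey` uses. The class and its default arguments assume a live `SparkEnv` (serializer, disk block manager), so in practice this only runs inside an executor or a test with a `SparkContext`:

```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.util.collection.ExternalAppendOnlyMap

val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
  (v: Int) => ArrayBuffer(v),                                 // createCombiner
  (buf: ArrayBuffer[Int], v: Int) => buf += v,                // mergeValue
  (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++= b2)  // mergeCombiners
(1 to 100000).foreach(i => map.insert(i % 100, i))
// The iterator transparently merges the in-memory map with any on-disk spills
map.iterator.take(3).foreach(println)
```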
+ * + * Two parameters control the memory threshold: + * + * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing + * these maps as a fraction of the executor's total memory. Since each concurrently running + * task maintains one map, the actual threshold for each map is this quantity divided by the + * number of running tasks. + * + * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of + * this threshold, in case map size estimation is not sufficiently accurate. + */ + +private[spark] class ExternalAppendOnlyMap[K, V, C]( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer = SparkEnv.get.serializerManager.default, + diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager) + extends Iterable[(K, C)] with Serializable with Logging { + + import ExternalAppendOnlyMap._ + + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] + private val spilledMaps = new ArrayBuffer[DiskMapIterator] + private val sparkConf = SparkEnv.get.conf + + // Collective memory threshold shared across all running tasks + private val maxMemoryThreshold = { + val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3) + val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8) + (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong + } + + // Number of pairs in the in-memory map + private var numPairsInMemory = 0 + + // Number of in-memory pairs inserted before tracking the map's shuffle memory usage + private val trackMemoryThreshold = 1000 + + // How many times we have spilled so far + private var spillCount = 0 + + private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024 + private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false) + private val comparator = new KCComparator[K, C] + private val ser = serializer.newInstance() + + /** + * Insert the given key and value into the map. + * + * If the underlying map is about to grow, check if the global pool of shuffle memory has + * enough room for this to happen. If so, allocate the memory required to grow the map; + * otherwise, spill the in-memory map to disk. + * + * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. 
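Back-of-envelope arithmetic for the threshold defined above, under the default fractions; the 512 MB heap and the 4 tasks below are illustrative numbers, not values from the patch:

```scala
val maxHeap = 512L * 1024 * 1024  // assumed executor -Xmx
val memoryFraction = 0.3          // spark.shuffle.memoryFraction default
val safetyFraction = 0.8          // spark.shuffle.safetyFraction default
val maxMemoryThreshold = (maxHeap * memoryFraction * safetyFraction).toLong
println(maxMemoryThreshold / (1024 * 1024))  // ~122 MB shared by all tasks
// With 4 concurrently running tasks, each map may double in size until the
// collective usage would exceed ~122 MB, i.e. roughly 30 MB apiece.
```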
+ */ + def insert(key: K, value: V) { + val update: (Boolean, C) => C = (hadVal, oldVal) => { + if (hadVal) mergeValue(oldVal, value) else createCombiner(value) + } + if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) { + val mapSize = currentMap.estimateSize() + var shouldSpill = false + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + + // Atomically check whether there is sufficient memory in the global pool for + // this map to grow and, if possible, allocate the required amount + shuffleMemoryMap.synchronized { + val threadId = Thread.currentThread().getId + val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId) + val availableMemory = maxMemoryThreshold - + (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L)) + + // Assume map growth factor is 2x + shouldSpill = availableMemory < mapSize * 2 + if (!shouldSpill) { + shuffleMemoryMap(threadId) = mapSize * 2 + } + } + // Do not synchronize spills + if (shouldSpill) { + spill(mapSize) + } + } + currentMap.changeValue(key, update) + numPairsInMemory += 1 + } + + /** + * Sort the existing contents of the in-memory map and spill them to a temporary file on disk + */ + private def spill(mapSize: Long) { + spillCount += 1 + logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + val (blockId, file) = diskBlockManager.createTempBlock() + val writer = + new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites) + try { + val it = currentMap.destructiveSortedIterator(comparator) + while (it.hasNext) { + val kv = it.next() + writer.write(kv) + } + writer.commit() + } finally { + // Partial failures cannot be tolerated; do not revert partial writes + writer.close() + } + currentMap = new SizeTrackingAppendOnlyMap[K, C] + spilledMaps.append(new DiskMapIterator(file)) + + // Reset the amount of shuffle memory used by this map in the global pool + val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap + shuffleMemoryMap.synchronized { + shuffleMemoryMap(Thread.currentThread().getId) = 0 + } + numPairsInMemory = 0 + } + + /** + * Return an iterator that merges the in-memory map with the spilled maps. + * If no spill has occurred, simply return the in-memory map's iterator. + */ + override def iterator: Iterator[(K, C)] = { + if (spilledMaps.isEmpty) { + currentMap.iterator + } else { + new ExternalIterator() + } + } + + /** + * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps + */ + private class ExternalIterator extends Iterator[(K, C)] { + + // A fixed-size queue that maintains a buffer for each stream we are currently merging + val mergeHeap = new mutable.PriorityQueue[StreamBuffer] + + // Input streams are derived both from the in-memory map and spilled maps on disk + // The in-memory map is sorted in place, while the spilled maps are already in sorted order + val sortedMap = currentMap.destructiveSortedIterator(comparator) + val inputStreams = Seq(sortedMap) ++ spilledMaps + + inputStreams.foreach { it => + val kcPairs = getMorePairs(it) + mergeHeap.enqueue(StreamBuffer(it, kcPairs)) + } + + /** + * Fetch from the given iterator until a key of different hash is retrieved. In the + * event of key hash collisions, this ensures no pairs are hidden from being merged. + * Assume the given iterator is in sorted order. 
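Why `getMorePairs` (below) must read *every* pair tied for the minimum hash: distinct keys can share a hash, as this classic JVM example shows, so stopping at the first key of a given hash could hide pairs from being merged.

```scala
println("Aa".hashCode())  // 2112
println("BB".hashCode())  // 2112, same hash but a different key
```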
+ */ + def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = { + val kcPairs = new ArrayBuffer[(K, C)] + if (it.hasNext) { + var kc = it.next() + kcPairs += kc + val minHash = kc._1.hashCode() + while (it.hasNext && kc._1.hashCode() == minHash) { + kc = it.next() + kcPairs += kc + } + } + kcPairs + } + + /** + * If the given buffer contains a value for the given key, merge that value into + * baseCombiner and remove the corresponding (K, C) pair from the buffer + */ + def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = { + var i = 0 + while (i < buffer.pairs.size) { + val (k, c) = buffer.pairs(i) + if (k == key) { + buffer.pairs.remove(i) + return mergeCombiners(baseCombiner, c) + } + i += 1 + } + baseCombiner + } + + /** + * Return true if there exists an input stream that still has unvisited pairs + */ + override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty) + + /** + * Select a key with the minimum hash, then combine all values with the same key from all input streams. + */ + override def next(): (K, C) = { + // Select a key from the StreamBuffer that holds the lowest key hash + val minBuffer = mergeHeap.dequeue() + val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) + if (minPairs.length == 0) { + // Should only happen when no other stream buffers have any pairs left + throw new NoSuchElementException + } + var (minKey, minCombiner) = minPairs.remove(0) + assert(minKey.hashCode() == minHash) + + // For all other streams that may have this key (i.e. have the same minimum key hash), + // merge in the corresponding value (if any) from that stream + val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer) + while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) { + val newBuffer = mergeHeap.dequeue() + minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer) + mergedBuffers += newBuffer + } + + // Repopulate each visited stream buffer and add it back to the merge heap + mergedBuffers.foreach { buffer => + if (buffer.pairs.length == 0) { + buffer.pairs ++= getMorePairs(buffer.iterator) + } + mergeHeap.enqueue(buffer) + } + + (minKey, minCombiner) + } + + /** + * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash. + * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to + * hash collisions, it is possible for multiple keys to be "tied" for being the lowest. + * + * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. 
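The negation in `compareTo` (below) exists because Scala's `mutable.PriorityQueue` dequeues its *largest* element; inverting the comparison turns it into the min-heap the merge loop needs. A standalone sketch of the trick, with a hypothetical `Buf` standing in for `StreamBuffer`:

```scala
import scala.collection.mutable

case class Buf(minKeyHash: Int) extends Comparable[Buf] {
  // minus sign: PriorityQueue is a max-heap, so invert the ordering
  override def compareTo(other: Buf): Int = -minKeyHash.compareTo(other.minKeyHash)
}
val heap = new mutable.PriorityQueue[Buf]()
heap.enqueue(Buf(3), Buf(1), Buf(2))
println(heap.dequeue().minKeyHash)  // 1: the smallest hash surfaces first
```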
+ */ + case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)]) + extends Comparable[StreamBuffer] { + + def minKeyHash: Int = { + if (pairs.length > 0){ + // pairs are already sorted by key hash + pairs(0)._1.hashCode() + } else { + Int.MaxValue + } + } + + override def compareTo(other: StreamBuffer): Int = { + // minus sign because mutable.PriorityQueue dequeues the max, not the min + -minKeyHash.compareTo(other.minKeyHash) + } + } + } + + /** + * An iterator that returns (K, C) pairs in sorted order from an on-disk map + */ + private class DiskMapIterator(file: File) extends Iterator[(K, C)] { + val fileStream = new FileInputStream(file) + val bufferedStream = new FastBufferedInputStream(fileStream) + val deserializeStream = ser.deserializeStream(bufferedStream) + var nextItem: (K, C) = null + var eof = false + + def readNextItem(): (K, C) = { + if (!eof) { + try { + return deserializeStream.readObject().asInstanceOf[(K, C)] + } catch { + case e: EOFException => + eof = true + cleanup() + } + } + null + } + + override def hasNext: Boolean = { + if (nextItem == null) { + nextItem = readNextItem() + } + nextItem != null + } + + override def next(): (K, C) = { + val item = if (nextItem == null) readNextItem() else nextItem + if (item == null) { + throw new NoSuchElementException + } + nextItem = null + item + } + + // TODO: Ensure this gets called even if the iterator isn't drained. + def cleanup() { + deserializeStream.close() + file.delete() + } + } +} + +private[spark] object ExternalAppendOnlyMap { + private class KCComparator[K, C] extends Comparator[(K, C)] { + def compare(kc1: (K, C), kc2: (K, C)): Int = { + kc1._1.hashCode().compareTo(kc2._1.hashCode()) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala new file mode 100644 index 0000000000..204330dad4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.util.SizeEstimator +import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample + +/** + * Append-only map that keeps track of its estimated size in bytes. + * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, + * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds). + */ +private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] { + + /** + * Controls the base of the exponential which governs the rate of sampling. 
+ * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements. + */ + private val SAMPLE_GROWTH_RATE = 1.1 + + /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */ + private val samples = new ArrayBuffer[Sample]() + + /** Total number of insertions and updates into the map since the last resetSamples(). */ + private var numUpdates: Long = _ + + /** The value of 'numUpdates' at which we will take our next sample. */ + private var nextSampleNum: Long = _ + + /** The average number of bytes per update between our last two samples. */ + private var bytesPerUpdate: Double = _ + + resetSamples() + + /** Called after the map grows in size, as this can be a dramatic change for small objects. */ + def resetSamples() { + numUpdates = 1 + nextSampleNum = 1 + samples.clear() + takeSample() + } + + override def update(key: K, value: V): Unit = { + super.update(key, value) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + } + + override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val newValue = super.changeValue(key, updateFunc) + numUpdates += 1 + if (nextSampleNum == numUpdates) { takeSample() } + newValue + } + + /** Takes a new sample of the current map's size. */ + def takeSample() { + samples += Sample(SizeEstimator.estimate(this), numUpdates) + // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change. + bytesPerUpdate = math.max(0, samples.toSeq.reverse match { + case latest :: previous :: tail => + (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates) + case _ => + 0 + }) + nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong + } + + override protected def growTable() { + super.growTable() + resetSamples() + } + + /** Estimates the current size of the map in bytes. O(1) time. */ + def estimateSize(): Long = { + assert(samples.nonEmpty) + val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates) + (samples.last.size + extrapolatedDelta).toLong + } +} + +private object SizeTrackingAppendOnlyMap { + case class Sample(size: Long, numUpdates: Long) +} diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala new file mode 100644 index 0000000000..93f0c6a8e6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
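Stepping back to `takeSample()` and `estimateSize()` above: a worked example of the extrapolation, with made-up sample values. Only the last two samples and a bytes-per-update rate are kept, which is what makes the estimate O(1).

```scala
case class Sample(size: Long, numUpdates: Long)
val previous = Sample(1000L, 100L)
val latest = Sample(1400L, 140L)
val bytesPerUpdate =
  (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
val numUpdates = 165L  // updates seen since the last sample was taken
val estimate = (latest.size + bytesPerUpdate * (numUpdates - latest.numUpdates)).toLong
println(estimate)  // 1650 = 1400 + 10 bytes/update * 25 updates
```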
+ */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass +import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap} + +class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll { + val NORMAL_ERROR = 0.20 + val HIGH_ERROR = 0.30 + + test("fixed size insertions") { + testWith[Int, Long](10000, i => (i, i.toLong)) + testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) + testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass())) + } + + test("variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[Int, String](10000, i => (i, randString(0, 10))) + testWith[Int, String](10000, i => (i, randString(0, 100))) + testWith[Int, String](10000, i => (i, randString(90, 100))) + } + + test("updates") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[String, Int](10000, i => (randString(0, 10000), i)) + } + + def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) { + val map = new SizeTrackingAppendOnlyMap[K, V]() + for (i <- 0 until numElements) { + val (k, v) = makeElement(i) + map(k) = v + expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { + val betterEstimatedSize = SizeEstimator.estimate(obj) + assert(betterEstimatedSize * (1 - error) < estimatedSize, + s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") + assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, + s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") + } +} + +object SizeTrackingAppendOnlyMapSuite { + // Speed test, for reproducibility of results. + // These could be highly non-deterministic in general, however. 
+ // Results: + // AppendOnlyMap: 31 ms + // SizeTracker: 54 ms + // SizeEstimator: 1500 ms + def main(args: Array[String]) { + val numElements = 100000 + + val baseTimes = for (i <- 0 until 10) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + } + } + + val sampledTimes = for (i <- 0 until 10) yield time { + val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + map.estimateSize() + } + } + + val unsampledTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + SizeEstimator.estimate(map) + } + } + + println("Base: " + baseTimes) + println("SizeTracker (sampled): " + sampledTimes) + println("SizeEstimator (unsampled): " + unsampledTimes) + } + + def time(f: => Unit): Long = { + val start = System.currentTimeMillis() + f + System.currentTimeMillis() - start + } + + private class LargeDummyClass { + val arr = new Array[Int](100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala new file mode 100644 index 0000000000..7006571ef0 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.FunSuite + +/** + * Tests org.apache.spark.util.Vector functionality + */ +class VectorSuite extends FunSuite { + + def verifyVector(vector: Vector, expectedLength: Int) = { + assert(vector.length == expectedLength) + assert(vector.elements.min > 0.0) + assert(vector.elements.max < 1.0) + } + + test("random with default random number generator") { + val vector100 = Vector.random(100) + verifyVector(vector100, 100) + } + + test("random with given random number generator") { + val vector100 = Vector.random(100, new Random(100)) + verifyVector(vector100, 100) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala index 7177919a58..f44442f1a5 100644 --- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala @@ -15,11 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.util +package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import java.util.Comparator class AppendOnlyMapSuite extends FunSuite { test("initialization") { @@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite { assert(map("" + i) === "" + i) } } + + test("destructive sort") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + map.update(null, "happy new year!") + + try { + map.apply("1") + map.update("1", "2013") + map.changeValue("1", (hadValue, oldValue) => "2014") + map.iterator + } catch { + case e: IllegalStateException => fail() + } + + val it = map.destructiveSortedIterator(new Comparator[(String, String)] { + def compare(kv1: (String, String), kv2: (String, String)): Int = { + val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue + val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue + x.compareTo(y) + } + }) + + // Should be sorted by key + assert(it.hasNext) + var previous = it.next() + assert(previous == (null, "happy new year!")) + previous = it.next() + assert(previous == ("1", "2014")) + while (it.hasNext) { + val kv = it.next() + assert(kv._1.toInt > previous._1.toInt) + previous = kv + } + + // All subsequent calls to apply, update, changeValue and iterator should throw exception + intercept[AssertionError] { map.apply("1") } + intercept[AssertionError] { map.update("1", "2013") } + intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") } + intercept[AssertionError] { map.iterator } + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala new file mode 100644 index 0000000000..ef957bb0e5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -0,0 +1,230 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark._ +import org.apache.spark.SparkContext._ + +class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + override def beforeEach() { + val conf = new SparkConf(false) + conf.set("spark.shuffle.externalSorting", "true") + sc = new SparkContext("local", "test", conf) + } + + val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i) + val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => { + buffer += i + } + val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] = + (buf1, buf2) => { + buf1 ++= buf2 + } + + test("simple insert") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + + // Single insert + map.insert(1, 10) + var it = map.iterator + assert(it.hasNext) + val kv = it.next() + assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10)) + assert(!it.hasNext) + + // Multiple insert + map.insert(2, 20) + map.insert(3, 30) + it = map.iterator + assert(it.hasNext) + assert(it.toSet == Set[(Int, ArrayBuffer[Int])]( + (1, ArrayBuffer[Int](10)), + (2, ArrayBuffer[Int](20)), + (3, ArrayBuffer[Int](30)))) + } + + test("insert with collision") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + + map.insert(1, 10) + map.insert(2, 20) + map.insert(3, 
30) + map.insert(1, 100) + map.insert(2, 200) + map.insert(1, 1000) + val it = map.iterator + assert(it.hasNext) + val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](10, 100, 1000)), + (2, Set[Int](20, 200)), + (3, Set[Int](30)))) + } + + test("ordering") { + val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map1.insert(1, 10) + map1.insert(2, 20) + map1.insert(3, 30) + + val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map2.insert(2, 20) + map2.insert(3, 30) + map2.insert(1, 10) + + val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map3.insert(3, 30) + map3.insert(1, 10) + map3.insert(2, 20) + + val it1 = map1.iterator + val it2 = map2.iterator + val it3 = map3.iterator + + var kv1 = it1.next() + var kv2 = it2.next() + var kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + + kv1 = it1.next() + kv2 = it2.next() + kv3 = it3.next() + assert(kv1._1 == kv2._1 && kv2._1 == kv3._1) + assert(kv1._2 == kv2._2 && kv2._2 == kv3._2) + } + + test("null keys and values") { + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, + mergeValue, mergeCombiners) + map.insert(1, 5) + map.insert(2, 6) + map.insert(3, 7) + assert(map.size === 3) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)) + )) + + // Null keys + val nullInt = null.asInstanceOf[Int] + map.insert(nullInt, 8) + assert(map.size === 4) + assert(map.iterator.toSet == Set[(Int, Seq[Int])]( + (1, Seq[Int](5)), + (2, Seq[Int](6)), + (3, Seq[Int](7)), + (nullInt, Seq[Int](8)) + )) + + // Null values + map.insert(4, nullInt) + map.insert(nullInt, nullInt) + assert(map.size === 5) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) + assert(result == Set[(Int, Set[Int])]( + (1, Set[Int](5)), + (2, Set[Int](6)), + (3, Set[Int](7)), + (4, Set[Int](nullInt)), + (nullInt, Set[Int](nullInt, 8)) + )) + } + + test("simple aggregator") { + // reduceByKey + val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1)) + val result1 = rdd.reduceByKey(_+_).collect() + assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5))) + + // groupByKey + val result2 = rdd.groupByKey().collect() + assert(result2.toSet == Set[(Int, Seq[Int])] + ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1)))) + } + + test("simple cogroup") { + val rdd1 = sc.parallelize(1 to 4).map(i => (i, i)) + val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i)) + val result = rdd1.cogroup(rdd2).collect() + + result.foreach { case (i, (seq1, seq2)) => + i match { + case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4)) + case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3)) + case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]()) + case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]()) + case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]()) + } + } + } + + test("spilling") { + // TODO: Figure out correct memory parameters to actually induce spilling + // 
System.setProperty("spark.shuffle.buffer.mb", "1") + // System.setProperty("spark.shuffle.buffer.fraction", "0.05") + + // reduceByKey - should spill exactly 6 times + val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i)) + val resultA = rddA.reduceByKey(math.max(_, _)).collect() + assert(resultA.length == 5000) + resultA.foreach { case(k, v) => + k match { + case 0 => assert(v == 1) + case 2500 => assert(v == 5001) + case 4999 => assert(v == 9999) + case _ => + } + } + + // groupByKey - should spill exactly 11 times + val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i)) + val resultB = rddB.groupByKey().collect() + assert(resultB.length == 2500) + resultB.foreach { case(i, seq) => + i match { + case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3)) + case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003)) + case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999)) + case _ => + } + } + + // cogroup - should spill exactly 7 times + val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i)) + val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i)) + val resultC = rddC1.cogroup(rddC2).collect() + assert(resultC.length == 1000) + resultC.foreach { case(i, (seq1, seq2)) => + i match { + case 0 => + assert(seq1.toSet == Set[Int](0)) + assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900)) + case 500 => + assert(seq1.toSet == Set[Int](500)) + assert(seq2.toSet == Set[Int]()) + case 999 => + assert(seq1.toSet == Set[Int](999)) + assert(seq2.toSet == Set[Int]()) + case _ => + } + } + } + + // TODO: Test memory allocation for multiple concurrently running tasks +} diff --git a/docs/configuration.md b/docs/configuration.md index b1a0e19167..ad75e06fc7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -104,14 +104,25 @@ Apart from these, the following properties are also available, and may be useful </tr> <tr> <td>spark.storage.memoryFraction</td> - <td>0.66</td> + <td>0.6</td> <td> Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old" - generation of objects in the JVM, which by default is given 2/3 of the heap, but you can increase + generation of objects in the JVM, which by default is given 0.6 of the heap, but you can increase it if you configure your own old generation size. </td> </tr> <tr> + <td>spark.shuffle.memoryFraction</td> + <td>0.3</td> + <td> + Fraction of Java heap to use for aggregation and cogroups during shuffles, if + <code>spark.shuffle.externalSorting</code> is enabled. At any given time, the collective size of + all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will + begin to spill to disk. If spills happen often, consider increasing this value at the expense of + <code>spark.storage.memoryFraction</code>. + </td> +</tr> +<tr> <td>spark.mesos.coarse</td> <td>false</td> <td> @@ -377,6 +388,14 @@ Apart from these, the following properties are also available, and may be useful </td> </tr> <tr> + <td>spark.shuffle.externalSorting</td> + <td>true</td> + <td> + If set to "true", limits the amount of memory used during reduces by spilling data out to disk. This spilling + threshold is specified by <code>spark.shuffle.memoryFraction</code>.
+ </td> +</tr> +<tr> <td>spark.speculation</td> <td>false</td> <td> diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index b206270107..3bd62646ba 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -101,7 +101,19 @@ With this mode, your application is actually run on the remote machine where the With yarn-client mode, the application will be launched locally, just like running an application or spark-shell in Local / Mesos / Standalone mode. The launch method is similar; just make sure that when you need to specify a master URL, use "yarn-client" instead. You also need to export the env values for SPARK_JAR and SPARK_YARN_APP_JAR -In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh +Configuration in yarn-client mode: + +In order to tune the number of workers and their cores and memory, you need to export environment variables or add them to the Spark configuration file (./conf/spark-env.sh). The following options are available: + +* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required) +* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2) +* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1) +* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G) +* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 MB) +* `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark) +* `SPARK_YARN_QUEUE`, The Hadoop queue to use for allocation requests (Default: 'default') +* `SPARK_YARN_DIST_FILES`, Comma-separated list of files to be distributed with the job. +* `SPARK_YARN_DIST_ARCHIVES`, Comma-separated list of archives to be distributed with the job. For example: @@ -114,7 +126,6 @@ For example: SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ MASTER=yarn-client ./bin/spark-shell -You can also send extra files to yarn cluster for worker to use by exporting SPARK_YARN_DIST_FILES=file1,file2... etc.
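As a rough sketch of how the two new shuffle-spilling properties documented in configuration.md above fit together, the driver below enables external sorting and caps the in-memory shuffle maps before running an aggregation that could otherwise grow without bound. This is illustrative only: the object name, master URL, dataset, and property values are placeholders (both values shown are the documented defaults), not tuning advice.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object ShuffleSpillSketch {
      def main(args: Array[String]) {
        // Illustrative values only; spark.shuffle.externalSorting already
        // defaults to true, and 0.3 is the documented default fraction.
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("ShuffleSpillSketch")
          .set("spark.shuffle.externalSorting", "true")
          .set("spark.shuffle.memoryFraction", "0.3")
        val sc = new SparkContext(conf)
        // With external sorting enabled, the combiners backing this reduce
        // may spill to disk instead of holding every key in memory.
        val counts = sc.parallelize(0 until 100000).map(i => (i % 1000, 1)).reduceByKey(_ + _)
        println(counts.count())
        sc.stop()
      }
    }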
# Building Spark for Hadoop/YARN 2.2.x diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index d82a1e1490..e7cb5ab3ff 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -185,7 +185,11 @@ def get_spark_ami(opts): "hi1.4xlarge": "hvm", "m3.xlarge": "hvm", "m3.2xlarge": "hvm", - "cr1.8xlarge": "hvm" + "cr1.8xlarge": "hvm", + "i2.xlarge": "hvm", + "i2.2xlarge": "hvm", + "i2.4xlarge": "hvm", + "i2.8xlarge": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -478,7 +482,11 @@ def get_num_disks(instance_type): "cr1.8xlarge": 2, "hi1.4xlarge": 2, "m3.xlarge": 0, - "m3.2xlarge": 0 + "m3.2xlarge": 0, + "i2.xlarge": 1, + "i2.2xlarge": 2, + "i2.4xlarge": 4, + "i2.8xlarge": 8 } if instance_type in disks_by_instance: return disks_by_instance[instance_type] diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java index 2e616b1ab2..349d826ab5 100644 --- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java @@ -48,7 +48,7 @@ public final class JavaNetworkWordCount { public static void main(String[] args) { if (args.length < 3) { - System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" + + System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" + "In local mode, <master> should be 'local[n]' with n > 1"); System.exit(1); } @@ -56,12 +56,12 @@ public final class JavaNetworkWordCount { StreamingExamples.setStreamingLogLevels(); // Create the context with a 1 second batch size - JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount", + JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount", new Duration(1000), System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class)); // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. generated by 'nc') JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() { @Override @@ -84,6 +84,5 @@ public final class JavaNetworkWordCount { wordCounts.print(); ssc.start(); - } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala index c12139b3ec..25f7013307 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala @@ -21,7 +21,8 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ /** - * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + * Counts words in text encoded with UTF8 received from the network every second. + * * Usage: NetworkWordCount <master> <hostname> <port> * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. 
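For orientation, the pipeline that NetworkWordCount and JavaNetworkWordCount build boils down to the sketch below; the object name, master URL, host, and port are placeholders, and `nc -lk 9999` supplies the \n-delimited input as the examples note.

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    object SocketWordCountSketch {
      def main(args: Array[String]) {
        // One-second batches over a text stream read from a TCP socket.
        val ssc = new StreamingContext("local[2]", "SocketWordCountSketch", Seconds(1))
        val lines = ssc.socketTextStream("localhost", 9999)
        // Split each line into words and count occurrences per batch.
        val wordCounts = lines.flatMap(_.split(" ")).map(x => (x, 1)).reduceByKey(_ + _)
        wordCounts.print()
        ssc.start()
      }
    }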
@@ -46,7 +47,7 @@ object NetworkWordCount { System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited text (eg. generated by 'nc') val lines = ssc.socketTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala new file mode 100644 index 0000000000..d51e6e9418 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{Time, Seconds, StreamingContext} + import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.util.IntParam +import java.io.File +import org.apache.spark.rdd.RDD +import com.google.common.io.Files +import java.nio.charset.Charset + +/** + * Counts words in text encoded with UTF8 received from the network every second. + * + * Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data. + * <checkpoint-directory> directory on an HDFS-compatible file system to which checkpoint data will be saved + * <output-file> file to which the word counts will be appended + * + * In local mode, <master> should be 'local[n]' with n > 1 + * <checkpoint-directory> and <output-file> must be absolute paths + * + * + * To run this on your local machine, you need to first run a Netcat server + * + * `$ nc -lk 9999` + * + * and run the example as + * + * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount \ + * local[2] localhost 9999 ~/checkpoint/ ~/out` + * + * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + * the checkpoint data.
+ * + * To run this example in a local standalone cluster with automatic driver recovery, + * + * `$ ./spark-class org.apache.spark.deploy.Client -s launch <cluster-url> <path-to-examples-jar> \ + * org.apache.spark.streaming.examples.RecoverableNetworkWordCount <cluster-url> \ + * localhost 9999 ~/checkpoint ~/out` + * + * <path-to-examples-jar> would typically be <spark-dir>/examples/target/scala-XX/spark-examples....jar + * + * Refer to the online documentation for more details. + */ + +object RecoverableNetworkWordCount { + + def createContext(master: String, ip: String, port: Int, outputPath: String) = { + + // If you do not see this printed, it means the StreamingContext has been loaded + // from the existing checkpoint data + println("Creating new context") + val outputFile = new File(outputPath) + if (outputFile.exists()) outputFile.delete() + + // Create the context with a 1 second batch size + val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1), + System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) + + // Create a NetworkInputDStream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + val lines = ssc.socketTextStream(ip, port) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.foreach((rdd: RDD[(String, Int)], time: Time) => { + val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") + println(counts) + println("Appending to " + outputFile.getAbsolutePath) + Files.append(counts + "\n", outputFile, Charset.defaultCharset()) + }) + ssc + } + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println( + """ + |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file> + | <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + | <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ | <checkpoint-directory> directory on an HDFS-compatible file system to which checkpoint data will be saved + | <output-file> file to which the word counts will be appended + | + |In local mode, <master> should be 'local[n]' with n > 1 + |Both <checkpoint-directory> and <output-file> must be absolute paths + """.stripMargin + ) + System.exit(1) + } + val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args + val ssc = StreamingContext.getOrCreate(checkpointDirectory, + () => { + createContext(master, ip, port, outputPath) + }) + ssc.start() + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index ca0115f90e..1249ef4c3d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -24,10 +24,10 @@ import java.util.concurrent.RejectedExecutionException import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SparkException, SparkConf, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.util.MetadataCleaner -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.streaming.scheduler.JobGenerator private[streaming] @@ -44,6 +44,10 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.conf + // These should be unset when a checkpoint is deserialized, + // otherwise the SparkContext won't initialize correctly. + sparkConf.remove("spark.hostPort").remove("spark.driver.host").remove("spark.driver.port") + def validate() { assert(master != null, "Checkpoint.master is null") assert(framework != null, "Checkpoint.framework is null") @@ -53,59 +57,119 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } +private[streaming] +object Checkpoint extends Logging { + val PREFIX = "checkpoint-" + val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r + + /** Get the checkpoint file for the given checkpoint time */ + def checkpointFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds) + } + + /** Get the checkpoint backup file for the given checkpoint time */ + def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = { + new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk") + } + + /** Get checkpoint files present in the given directory, ordered by oldest-first */ + def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = { + def sortFunc(path1: Path, path2: Path): Boolean = { + val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } + (time1 < time2) || (time1 == time2 && bk1) + } + + val path = new Path(checkpointDir) + if (fs.exists(path)) { + val statuses = fs.listStatus(path) + if (statuses != null) { + val paths = statuses.map(_.getPath) + val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty) + filtered.sortWith(sortFunc) + } else { + logWarning("Listing " + path + " returned null") + Seq.empty + } + } else { + logInfo("Checkpoint directory " + path + " does not exist") + Seq.empty + } + } +} + /** * Convenience class to handle the writing of graph checkpoint to file */
private[streaming] -class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Configuration) - extends Logging -{ - val file = new Path(checkpointDir, "graph") +class CheckpointWriter( + jobGenerator: JobGenerator, + conf: SparkConf, + checkpointDir: String, + hadoopConf: Configuration + ) extends Logging { val MAX_ATTEMPTS = 3 val executor = Executors.newFixedThreadPool(1) val compressionCodec = CompressionCodec.createCodec(conf) - // The file to which we actually write - and then "move" to file - val writeFile = new Path(file.getParent, file.getName + ".next") - // The file to which existing checkpoint is backed up (i.e. "moved") - val bakFile = new Path(file.getParent, file.getName + ".bk") - private var stopped = false private var fs_ : FileSystem = _ - // Removed code which validates whether there is only one CheckpointWriter per path 'file' since - // I did not notice any errors - reintroduce it ? class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { def run() { var attempts = 0 val startTime = System.currentTimeMillis() + val tempFile = new Path(checkpointDir, "temp") + val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime) + val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime) + while (attempts < MAX_ATTEMPTS && !stopped) { attempts += 1 try { - logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") - // This is inherently thread unsafe, so alleviating it by writing to '.new' and - // then moving it to the final file - val fos = fs.create(writeFile) + logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'") + + // Write checkpoint to temp file + fs.delete(tempFile, true) // just in case it exists + val fos = fs.create(tempFile) fos.write(bytes) fos.close() - if (fs.exists(file) && fs.rename(file, bakFile)) { - logDebug("Moved existing checkpoint file to " + bakFile) + + // If the checkpoint file exists, back it up + // If the backup exists as well, just delete it, otherwise rename will fail + if (fs.exists(checkpointFile)) { + fs.delete(backupFile, true) // just in case it exists + if (!fs.rename(checkpointFile, backupFile)) { + logWarning("Could not rename " + checkpointFile + " to " + backupFile) + } + } + + // Rename temp file to the final checkpoint file + if (!fs.rename(tempFile, checkpointFile)) { + logWarning("Could not rename " + tempFile + " to " + checkpointFile) + } + + // Delete old checkpoint files, keeping the 10 most recent + val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs) + if (allCheckpointFiles.size > 10) { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + logInfo("Deleting " + file) + fs.delete(file, true) + }) } - // paranoia - fs.delete(file, false) - fs.rename(writeFile, file) + // All done, print success val finishTime = System.currentTimeMillis() - logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + - "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") + jobGenerator.onCheckpointCompletion(checkpointTime) return } catch { case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) + logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe) reset() } } -
logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") + logWarning("Could not write checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'") } } @@ -118,6 +182,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi bos.close() try { executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " to writer queue") } catch { case rej: RejectedExecutionException => logError("Could not submit checkpoint task to the thread pool executor", rej) @@ -140,7 +205,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi } private def fs = synchronized { - if (fs_ == null) fs_ = file.getFileSystem(hadoopConf) + if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf) fs_ } @@ -153,43 +218,46 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi private[streaming] object CheckpointReader extends Logging { - def read(conf: SparkConf, path: String): Checkpoint = { - val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), - new Path(path), new Path(path + ".bk")) + def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = { + val checkpointPath = new Path(checkpointDir) + def fs = checkpointPath.getFileSystem(hadoopConf) + + // Try to find the checkpoint files + val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse + if (checkpointFiles.isEmpty) { + return None + } + // Try to read the checkpoint files in order + logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) val compressionCodec = CompressionCodec.createCodec(conf) - - attempts.foreach(file => { - if (fs.exists(file)) { - logInfo("Attempting to load checkpoint from file '" + file + "'") - try { - val fis = fs.open(file) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val zis = compressionCodec.compressedInputStream(fis) - val ois = new ObjectInputStreamWithLoader(zis, - Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - logInfo("Checkpoint successfully loaded from file '" + file + "'") - logInfo("Checkpoint was generated at time " + cp.checkpointTime) - return cp - } catch { - case e: Exception => - logError("Error loading checkpoint from file '" + file + "'", e) - } - } else { - logWarning("Could not read checkpoint from file '" + file + "' as it does not exist") + checkpointFiles.foreach(file => { + logInfo("Attempting to load checkpoint from file " + file) + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which may be the wrong class loader. Hence, an inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes.
This is a well-known Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val zis = compressionCodec.compressedInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, + Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + logInfo("Checkpoint successfully loaded from file " + file) + logInfo("Checkpoint was generated at time " + cp.checkpointTime) + return Some(cp) + } catch { + case e: Exception => + logWarning("Error reading checkpoint from file " + file, e) } - }) - throw new Exception("Could not read checkpoint from path '" + path + "'") + + // If none of the checkpoint files could be read, throw an exception + throw new SparkException("Failed to read checkpoint from directory " + checkpointPath) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 837f1ea1d8..b98f4a5101 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -329,13 +329,12 @@ abstract class DStream[T: ClassTag] ( * implementation clears the old generated RDDs. Subclasses of DStream may override * this to clear their own metadata along with the generated RDDs. */ - protected[streaming] def clearOldMetadata(time: Time) { - var numForgotten = 0 + protected[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) generatedRDDs --= oldRDDs.keys logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) - dependencies.foreach(_.clearOldMetadata(time)) + dependencies.foreach(_.clearMetadata(time)) } /* Adds metadata to the Stream while it is running. @@ -356,12 +355,18 @@ abstract class DStream[T: ClassTag] ( */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) - checkpointData.update() + checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - checkpointData.cleanup() logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } + protected[streaming] def clearCheckpointData(time: Time) { + logInfo("Clearing checkpoint data") + checkpointData.cleanup(time) + dependencies.foreach(_.clearCheckpointData(time)) + logInfo("Cleared checkpoint data") + } + /** * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method * that should not be called directly.
This is a default implementation that recreates RDDs diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala index 3fd5d52403..671f7bbce7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala @@ -17,77 +17,86 @@ package org.apache.spark.streaming +import scala.collection.mutable.{HashMap, HashSet} +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.conf.Configuration -import collection.mutable.HashMap import org.apache.spark.Logging -import scala.collection.mutable.HashMap -import scala.reflect.ClassTag - +import java.io.{ObjectInputStream, IOException} private[streaming] class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() - @transient private var fileSystem : FileSystem = null - @transient private var lastCheckpointFiles: HashMap[Time, String] = null + // Mapping of the batch time to the checkpointed RDD file of that time + @transient private var timeToCheckpointFile = new HashMap[Time, String] + // Mapping of the batch time to the time of the oldest checkpointed RDD + // in that batch's checkpoint data + @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time] - protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] + @transient private var fileSystem : FileSystem = null + protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]] /** * Updates the checkpoint data of the DStream. This gets called every time * the graph checkpoint is initiated. Default implementation records the * checkpoint files to which the generate RDDs of the DStream has been saved. */ - def update() { + def update(time: Time) { // Get the checkpointed RDDs from the generated RDDs - val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - lastCheckpointFiles = checkpointFiles.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newCheckpointFiles.size > 0) { - checkpointFiles.clear() - checkpointFiles ++= newCheckpointFiles - } - - // TODO: remove this, this is just for debugging - newCheckpointFiles.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n")) + + // Add the checkpoint files to the data to be serialized + if (!checkpointFiles.isEmpty) { + currentCheckpointFiles.clear() + currentCheckpointFiles ++= checkpointFiles + // Add the current checkpoint files to the map of all checkpoint files + // This will be used to delete old checkpoint files + timeToCheckpointFile ++= currentCheckpointFiles + // Remember the time of the oldest checkpoint RDD in current state + timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering) } } /** - * Cleanup old checkpoint data. This gets called every time the graph - * checkpoint is initiated, but after `update` is called. 
Default - * implementation, cleans up old checkpoint files. + * Cleanup old checkpoint data. This gets called after a checkpoint of `time` has been + * written to the checkpoint directory. */ - def cleanup() { - // If there is at least on checkpoint file in the current checkpoint files, - // then delete the old checkpoint files. - if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { - (lastCheckpointFiles -- checkpointFiles.keySet).foreach { - case (time, file) => { - try { - val path = new Path(file) - if (fileSystem == null) { - fileSystem = path.getFileSystem(new Configuration()) + def cleanup(time: Time) { + // Get the time of the oldest checkpointed RDD that was written as part of the + // checkpoint of `time` + timeToOldestCheckpointFileTime.remove(time) match { + case Some(lastCheckpointFileTime) => + // Find all the checkpointed RDDs (i.e. files) that are older than `lastCheckpointFileTime` + // This is because checkpointed RDDs older than this are not going to be needed + // even after master fails, as the checkpoint data of `time` does not refer to those files + val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime) + logDebug("Files to delete:\n" + filesToDelete.mkString(",")) + filesToDelete.foreach { + case (time, file) => + try { + val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration) + } + fileSystem.delete(path, true) + timeToCheckpointFile -= time + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + fileSystem = null } - fileSystem.delete(path, true) - logInfo("Deleted checkpoint file '" + file + "' for time " + time) - } catch { - case e: Exception => - logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) - } } - } + case None => + logInfo("Nothing to delete") } } @@ -98,7 +107,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) */ def restore() { // Create RDDs from the checkpoint data - checkpointFiles.foreach { + currentCheckpointFiles.foreach { case(time, file) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) @@ -107,6 +116,13 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) } override def toString() = { - "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" + "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]" + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + timeToOldestCheckpointFileTime = new HashMap[Time, Time] + timeToCheckpointFile = new HashMap[Time, String] } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 62d07b22c6..eee9591ffc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -104,36 +104,44 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } def generateJobs(time: Time): Seq[Job] = { - this.synchronized { - logDebug("Generating 
jobs for time " + time) - val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) - logDebug("Generated " + jobs.length + " jobs for time " + time) - jobs + logDebug("Generating jobs for time " + time) + val jobs = this.synchronized { + outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } + logDebug("Generated " + jobs.length + " jobs for time " + time) + jobs } - def clearOldMetadata(time: Time) { + def clearMetadata(time: Time) { + logDebug("Clearing metadata for time " + time) this.synchronized { - logDebug("Clearing old metadata for time " + time) - outputStreams.foreach(_.clearOldMetadata(time)) - logDebug("Cleared old metadata for time " + time) + outputStreams.foreach(_.clearMetadata(time)) } + logDebug("Cleared old metadata for time " + time) } def updateCheckpointData(time: Time) { + logInfo("Updating checkpoint data for time " + time) this.synchronized { - logInfo("Updating checkpoint data for time " + time) outputStreams.foreach(_.updateCheckpointData(time)) - logInfo("Updated checkpoint data for time " + time) } + logInfo("Updated checkpoint data for time " + time) + } + + def clearCheckpointData(time: Time) { + logInfo("Clearing checkpoint data for time " + time) + this.synchronized { + outputStreams.foreach(_.clearCheckpointData(time)) + } + logInfo("Cleared checkpoint data for time " + time) } def restoreCheckpointData() { + logInfo("Restoring checkpoint data") this.synchronized { - logInfo("Restoring checkpoint data") outputStreams.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") } + logInfo("Restored checkpoint data") } def validate() { @@ -146,8 +154,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { + logDebug("DStreamGraph.writeObject used") this.synchronized { - logDebug("DStreamGraph.writeObject used") checkpointInProgress = true oos.defaultWriteObject() checkpointInProgress = false @@ -156,8 +164,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { + logDebug("DStreamGraph.readObject used") this.synchronized { - logDebug("DStreamGraph.readObject used") checkpointInProgress = true ois.defaultReadObject() checkpointInProgress = false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 693cb7fc30..dd34f6f4f2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -39,6 +39,7 @@ import org.apache.spark.util.MetadataCleaner import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receivers._ import org.apache.spark.streaming.scheduler._ +import org.apache.hadoop.conf.Configuration /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -88,10 +89,12 @@ class StreamingContext private ( /** * Re-create a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. 
+ * @param path Path to the directory that was specified as the checkpoint directory + * @param hadoopConf Optional Hadoop configuration object, if necessary for reading from + * HDFS-compatible file systems */ - def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null) + def this(path: String, hadoopConf: Configuration = new Configuration) = + this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) if (sc_ == null && cp_ == null) { throw new Exception("Spark Streaming cannot be initialized with " + @@ -171,8 +174,9 @@ class StreamingContext private ( /** * Set the context to periodically checkpoint the DStream operations for master - * fault-tolerance. The graph will be checkpointed every batch interval. - * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored + * fault-tolerance. + * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. + * Note that this must be a fault-tolerant file system like HDFS. */ def checkpoint(directory: String) { if (directory != null) { @@ -461,26 +465,64 @@ class StreamingContext private ( } } +/** + * StreamingContext object contains a number of utility functions related to the + * StreamingContext class. + */ -object StreamingContext { +object StreamingContext extends Logging { implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the StreamingContext + * will be created by calling the provided `creatingFunc`. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param creatingFunc Function to create a new StreamingContext + * @param hadoopConf Optional Hadoop configuration if necessary for reading from the + * file system + * @param createOnError Optional, whether to create a new StreamingContext if there is an + * error in reading checkpoint data. By default, an exception will be + * thrown on error. + */ + def getOrCreate( + checkpointPath: String, + creatingFunc: () => StreamingContext, + hadoopConf: Configuration = new Configuration(), + createOnError: Boolean = false + ): StreamingContext = { + val checkpointOption = try { + CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf) + } catch { + case e: Exception => + if (createOnError) { + None + } else { + throw e + } + } + checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc()) + } + + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls) + protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals.
- val sc = new SparkContext(conf) - if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) + if (MetadataCleaner.getDelaySeconds(conf) < 0) { + MetadataCleaner.setDelaySeconds(conf, 3600) } + val sc = new SparkContext(conf) sc } @@ -489,14 +531,17 @@ object StreamingContext { appName: String, sparkHome: String, jars: Seq[String], - environment: Map[String, String]): SparkContext = - { - val sc = new SparkContext(master, appName, sparkHome, jars, environment) + environment: Map[String, String] + ): SparkContext = { + + val conf = SparkContext.updatedConf( + new SparkConf(), master, appName, sparkHome, jars, environment) // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. - if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { - MetadataCleaner.setDelaySeconds(sc.conf, 3600) + if (MetadataCleaner.getDelaySeconds(conf) < 0) { + MetadataCleaner.setDelaySeconds(conf, 3600) } + val sc = new SparkContext(conf) sc } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 7068f32517..523173d45a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.scheduler.StreamingListener +import org.apache.hadoop.conf.Configuration /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -128,10 +129,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Re-creates a StreamingContext from a checkpoint file. - * @param path Path either to the directory that was specified as the checkpoint directory, or - * to the checkpoint file 'graph' or 'graph.bk'. + * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this (new StreamingContext(path)) + def this(path: String) = this(new StreamingContext(path)) + + /** + * Re-creates a StreamingContext from a checkpoint file. + * @param path Path to the directory that was specified as the checkpoint directory + * @param hadoopConf Hadoop configuration if necessary for reading from an HDFS-compatible + * file system + */ + def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf)) /** The underlying SparkContext */ val sc: JavaSparkContext = new JavaSparkContext(ssc.sc) @@ -471,20 +478,97 @@ class JavaStreamingContext(val ssc: StreamingContext) { } /** - * Starts the execution of the streams. + * Start the execution of the streams. */ def start() = ssc.start() /** - * Sstops the execution of the streams. + * Stop the execution of the streams. */ def stop() = ssc.stop() } +/** + * JavaStreamingContext object contains a number of utility functions. + */ object JavaStreamingContext { + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext.
+ * + * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + */ + def getOrCreate( + checkpointPath: String, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf) + new JavaStreamingContext(ssc) + } + + /** + * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + * recreated from the checkpoint data. If the data does not exist, then the provided factory + * will be used to create a JavaStreamingContext. + * + * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program + * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext + * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible + * file system + * @param createOnError Whether to create a new JavaStreamingContext if there is an + * error in reading checkpoint data. + */ + def getOrCreate( + checkpointPath: String, + hadoopConf: Configuration, + factory: JavaStreamingContextFactory, + createOnError: Boolean + ): JavaStreamingContext = { + val ssc = StreamingContext.getOrCreate(checkpointPath, () => { + factory.create.ssc + }, hadoopConf, createOnError) + new JavaStreamingContext(ssc) + } + /** * Find the JAR from which a given class was loaded, to make it easy for users to pass - * their JARs to SparkContext. + * their JARs to StreamingContext. 
*/ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray } + +/** + * Factory interface for creating a new JavaStreamingContext + */ +trait JavaStreamingContextFactory { + def create(): JavaStreamingContext +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index fb9eda8996..1f0f31c4b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -23,10 +23,10 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import org.apache.spark.SparkException import org.apache.spark.rdd.RDD import org.apache.spark.rdd.UnionRDD import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time} +import org.apache.spark.util.TimeStampedHashMap private[streaming] @@ -46,6 +46,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas @transient private var path_ : Path = null @transient private var fs_ : FileSystem = null @transient private[streaming] var files = new HashMap[Time, Array[String]] + @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true) + @transient private var lastNewFileFindingTime = 0L override def start() { if (newFilesOnly) { @@ -88,14 +90,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } /** Clear the old time-to-files mappings along with old RDDs */ - protected[streaming] override def clearOldMetadata(time: Time) { - super.clearOldMetadata(time) + protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) val oldFiles = files.filter(_._1 <= (time - rememberDuration)) files --= oldFiles.keys logInfo("Cleared " + oldFiles.size + " old files that were older than " + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) logDebug("Cleared files are:\n" + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + // Delete file mod times that weren't accessed in the last round of getting new files + fileModTimes.clearOldValues(lastNewFileFindingTime - 1) } /** @@ -104,8 +108,19 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas */ private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = { logDebug("Trying to get new files for time " + currentTime) + lastNewFileFindingTime = System.currentTimeMillis val filter = new CustomPathFilter(currentTime) - val newFiles = fs.listStatus(path, filter).map(_.getPath.toString) + val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString) + val timeTaken = System.currentTimeMillis - lastNewFileFindingTime + logInfo("Finding new files took " + timeTaken + " ms") + logDebug("# cached file times = " + fileModTimes.size) + if (timeTaken > slideDuration.milliseconds) { + logWarning( + "Time taken to find new files exceeds the batch size. " + + "Consider increasing the batch size or reducing the number of " + + "files in the monitored directory."
+ ) + } (newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq) } @@ -122,16 +137,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas new UnionRDD(context.sparkContext, fileRDDs) } - private def path: Path = { + private def directoryPath: Path = { if (path_ == null) path_ = new Path(directory) path_ } private def fs: FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration()) fs_ } + private def getFileModTime(path: Path) = { + // Get file mod time from cache or fetch it from the file system + fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime()) + } + private def reset() { fs_ = null } @@ -142,6 +162,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[(K,V)]] () files = new HashMap[Time, Array[String]] + fileModTimes = new TimeStampedHashMap[String, Long](true) } /** @@ -153,15 +174,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] - override def update() { + override def update(time: Time) { hadoopFiles.clear() hadoopFiles ++= files } - override def cleanup() { } + override def cleanup(time: Time) { } override def restore() { - hadoopFiles.foreach { + hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, f) => { // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + @@ -187,14 +208,13 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas // Latest file mod time seen in this round of fetching files and its corresponding files var latestModTime = 0L val latestModTimeFiles = new HashSet[String]() - def accept(path: Path): Boolean = { try { if (!filter(path)) { // Reject file if it does not satisfy filter logDebug("Rejected by filter " + path) return false } - val modTime = fs.getFileStatus(path).getModificationTime() + val modTime = getFileModTime(path) logDebug("Mod time for " + path + " is " + modTime) if (modTime < prevModTime) { logDebug("Mod time less than last mod time") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 3c624e8199..2fa6853ae0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -26,8 +26,9 @@ import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent -private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent +private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent /** * This class generates jobs from DStreams as well as drives checkpointing and cleaning @@ -53,7 +54,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, longTime => eventProcessorActor ! 
GenerateJobs(new Time(longTime))) lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } @@ -77,15 +78,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { * On batch completion, clear old metadata and checkpoint computation. */ private[scheduler] def onBatchCompletion(time: Time) { - eventProcessorActor ! ClearOldMetadata(time) + eventProcessorActor ! ClearMetadata(time) + } + + private[streaming] def onCheckpointCompletion(time: Time) { + eventProcessorActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { event match { case GenerateJobs(time) => generateJobs(time) - case ClearOldMetadata(time) => clearOldMetadata(time) + case ClearMetadata(time) => clearMetadata(time) case DoCheckpoint(time) => doCheckpoint(time) + case ClearCheckpointData(time) => clearCheckpointData(time) } } @@ -115,14 +121,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val checkpointTime = ssc.initialCheckpoint.checkpointTime val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time: " + downTimes.mkString(", ")) + logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", ")) // Batches that were unprocessed before failure - val pendingTimes = ssc.initialCheckpoint.pendingTimes - logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) + logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) // Reschedule jobs for these times val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => jobScheduler.runJobs(time, graph.generateJobs(time)) ) @@ -141,11 +147,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } /** Clear DStream metadata for the given `time`. */ - private def clearOldMetadata(time: Time) { - ssc.graph.clearOldMetadata(time) + private def clearMetadata(time: Time) { + ssc.graph.clearMetadata(time) eventProcessorActor ! DoCheckpoint(time) } + /** Clear DStream checkpoint data for the given `time`. */ + private def clearCheckpointData(time: Time) { + ssc.graph.clearCheckpointData(time) + } + /** Perform checkpoint for the given `time`.
*/ private def doCheckpoint(time: Time) = synchronized { if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala index 1559f7a9f7..162b19d7f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala @@ -42,6 +42,7 @@ object MasterFailureTest extends Logging { @volatile var killed = false @volatile var killCount = 0 + @volatile var setupCalled = false def main(args: Array[String]) { if (args.size < 2) { @@ -131,8 +132,26 @@ object MasterFailureTest extends Logging { // Just making sure that the expected output does not have duplicates assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + // Setup the stream computation with the given operation - val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => { + setupStreams(batchDuration, operation, checkpointDir, testDir) + }) + + // Check if setupStreams was called to create StreamingContext + // (and not created from checkpoint file) + assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate") // Start generating files in a different thread val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) @@ -144,9 +163,7 @@ object MasterFailureTest extends Logging { val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) - // Delete directories fileGeneratingThread.join() - val fs = checkpointDir.getFileSystem(new Configuration()) fs.delete(checkpointDir, true) fs.delete(testDir, true) logInfo("Finished test after " + killCount + " failures") @@ -159,32 +176,24 @@ object MasterFailureTest extends Logging { * files should be written for testing.
private def setupStreams[T: ClassTag]( - directory: String, batchDuration: Duration, - operation: DStream[String] => DStream[T] - ): (StreamingContext, Path, Path) = { - // Reset all state - reset() - - // Create the directories for this test - val uuid = UUID.randomUUID().toString - val rootDir = new Path(directory, uuid) - val fs = rootDir.getFileSystem(new Configuration()) - val checkpointDir = new Path(rootDir, "checkpoint") - val testDir = new Path(rootDir, "test") - fs.mkdirs(checkpointDir) - fs.mkdirs(testDir) + operation: DStream[String] => DStream[T], + checkpointDir: Path, + testDir: Path + ): StreamingContext = { + // Mark that setup was called + setupCalled = true // Setup the streaming computation with the given operation System.clearProperty("spark.driver.port") System.clearProperty("spark.hostPort") - var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) + val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map()) ssc.checkpoint(checkpointDir.toString) val inputStream = ssc.textFileStream(testDir.toString) val operatedStream = operation(inputStream) val outputStream = new TestOutputStream(operatedStream) ssc.registerOutputStream(outputStream) - (ssc, checkpointDir, testDir) + ssc } @@ -204,7 +213,7 @@ object MasterFailureTest extends Logging { var isTimedOut = false val mergedOutput = new ArrayBuffer[T]() val checkpointDir = ssc.checkpointDir - var batchDuration = ssc.graph.batchDuration + val batchDuration = ssc.graph.batchDuration while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer @@ -261,7 +270,10 @@ object MasterFailureTest extends Logging { ) Thread.sleep(sleepTime) // Recreate the streaming context from checkpoint - ssc = new StreamingContext(checkpointDir) + ssc = StreamingContext.getOrCreate(checkpointDir, () => { + throw new Exception("Trying to create new context when it " + + "should be reading from checkpoint file") + }) } } mergedOutput @@ -297,6 +309,7 @@ object MasterFailureTest extends Logging { private def reset() { killed = false killCount = 0 + setupCalled = false } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 0d2145da9a..8b7d7709bf 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -28,6 +28,7 @@ import java.util.*; import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.io.Files; +import com.google.common.collect.Sets; import org.apache.spark.SparkConf; import org.apache.spark.HashPartitioner; @@ -441,13 +442,13 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa new Tuple2<String, String>("new york", "islanders"))); - List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList( - Arrays.asList( + List<HashSet<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList( + Sets.newHashSet( new Tuple2<String, Tuple2<String, String>>("california", new Tuple2<String, String>("dodgers", "giants")), new Tuple2<String, Tuple2<String, String>>("new york", - new Tuple2<String, String>("yankees", "mets"))), - Arrays.asList( + new Tuple2<String, String>("yankees", "mets"))), + Sets.newHashSet( new Tuple2<String, Tuple2<String, String>>("california", new Tuple2<String, String>("sharks", "ducks")), new Tuple2<String, Tuple2<String,
String>>("new york", @@ -482,8 +483,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa JavaTestUtils.attachTestOutputStream(joined); List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List<HashSet<Tuple2<String, Tuple2<String, String>>>> unorderedResult = Lists.newArrayList(); + for (List<Tuple2<String, Tuple2<String, String>>> res: result) { + unorderedResult.add(Sets.newHashSet(res)); + } - Assert.assertEquals(expected, result); + Assert.assertEquals(expected, unorderedResult); } @@ -1196,15 +1201,15 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa Arrays.asList("hello", "moon"), Arrays.asList("hello")); - List<List<Tuple2<String, Long>>> expected = Arrays.asList( - Arrays.asList( + List<HashSet<Tuple2<String, Long>>> expected = Arrays.asList( + Sets.newHashSet( new Tuple2<String, Long>("hello", 1L), new Tuple2<String, Long>("world", 1L)), - Arrays.asList( + Sets.newHashSet( new Tuple2<String, Long>("hello", 2L), new Tuple2<String, Long>("world", 1L), new Tuple2<String, Long>("moon", 1L)), - Arrays.asList( + Sets.newHashSet( new Tuple2<String, Long>("hello", 2L), new Tuple2<String, Long>("moon", 1L))); @@ -1214,8 +1219,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + List<HashSet<Tuple2<String, Long>>> unorderedResult = Lists.newArrayList(); + for (List<Tuple2<String, Long>> res: result) { + unorderedResult.add(Sets.newHashSet(res)); + } - Assert.assertEquals(expected, result); + Assert.assertEquals(expected, unorderedResult); } @Test diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 8dc80ac2ed..6499de98c9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -84,9 +84,9 @@ class CheckpointSuite extends TestSuiteBase { ssc.start() advanceTimeWithRealDelay(ssc, firstNumBatches) logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -95,7 +95,7 @@ class CheckpointSuite extends TestSuiteBase { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) + val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2)) advanceTimeWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) @@ -114,9 +114,9 @@ class CheckpointSuite extends TestSuiteBase { // is present in the checkpoint data or not ssc.start() 
advanceTimeWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.checkpointFiles.isEmpty, + assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.checkpointFiles.foreach { + stateStream.checkpointData.currentCheckpointFiles.foreach { case (time, file) => { assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before second failure does not exist") diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2bb11e54c5..2e46d750c4 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -127,14 +127,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // local dirs, so let's check both. We assume one of the 2 is set. // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l } - localDirs - } + } private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index ddfec1a4ac..62b20b8fba 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -76,6 +76,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar def run() { + // Setup the directories so things go to yarn approved directories rather + // than user specified and /tmp. + System.setProperty("spark.local.dir", getLocalDirs()) + appAttemptId = getApplicationAttemptId() resourceManager = registerWithResourceManager() val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() @@ -103,10 +107,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // must be <= timeoutInterval/ 2. - // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. - // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. - val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) + reporterThread = launchReporterThread(interval) // Wait for the reporter thread to Finish. @@ -119,6 +125,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar System.exit(0) } + /** Get the Yarn approved local directories. */
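The getLocalDirs refactor above replaces the empty-string sentinel (getOrElse("") followed by an isEmpty check) with Option.orElse and an explicit match. A standalone version of the same pattern, runnable outside YARN; using getOrElse with a by-name throw is an equivalent alternative to the match in the patch:

object LocalDirsSketch {
  def getLocalDirs(): String = {
    // Probe both env var names: YARN_LOCAL_DIRS (0.23.x) and LOCAL_DIRS (2.x).
    val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
      .orElse(Option(System.getenv("LOCAL_DIRS")))
    // Fail fast when neither is set; `throw` has type Nothing so it unifies
    // with String here.
    localDirs.getOrElse(throw new Exception("Yarn local dirs can't be empty"))
  }

  def main(args: Array[String]): Unit =
    println(getLocalDirs())
}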
+ private def getLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so let's check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l + } + } + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 4b1b5da048..22e55e0c60 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -22,6 +22,8 @@ import org.apache.spark.{SparkException, Logging, SparkContext} import org.apache.spark.deploy.yarn.{Client, ClientArguments} import org.apache.spark.scheduler.TaskSchedulerImpl +import scala.collection.mutable.ArrayBuffer + private[spark] class YarnClientSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) @@ -31,45 +33,47 @@ private[spark] class YarnClientSchedulerBackend( var client: Client = null var appId: ApplicationId = null + private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) { + Option(System.getenv(optionalParam)) foreach { + optParam => { + arrayBuf += (optionName, optParam) + } + } + } + override def start() { super.start() - val defalutWorkerCores = "2" - val defalutWorkerMemory = "512m" - val defaultWorkerNumber = "1" - val userJar = System.getenv("SPARK_YARN_APP_JAR") - val distFiles = System.getenv("SPARK_YARN_DIST_FILES") - var workerCores = System.getenv("SPARK_WORKER_CORES") - var workerMemory = System.getenv("SPARK_WORKER_MEMORY") - var workerNumber = System.getenv("SPARK_WORKER_INSTANCES") - if (userJar == null) throw new SparkException("env SPARK_YARN_APP_JAR is not set") - if (workerCores == null) - workerCores = defalutWorkerCores - if (workerMemory == null) - workerMemory = defalutWorkerMemory - if (workerNumber == null) - workerNumber = defaultWorkerNumber - val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort - val argsArray = Array[String]( + val argsArrayBuf = new ArrayBuffer[String]() + argsArrayBuf += ( "--class", "notused", "--jar", userJar, "--args", hostport, - "--worker-memory", workerMemory, - "--worker-cores", workerCores, - "--num-workers", workerNumber, - "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher", - "--files", distFiles + "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" ) - val args = new ClientArguments(argsArray, conf) + // process any optional arguments, use the defaults already defined in ClientArguments + // if things aren't specified + Map("--master-memory" -> "SPARK_MASTER_MEMORY", + "--num-workers" -> "SPARK_WORKER_INSTANCES", + "--worker-memory" -> "SPARK_WORKER_MEMORY", + "--worker-cores" -> "SPARK_WORKER_CORES", + "--queue" -> "SPARK_YARN_QUEUE", + "--name" -> "SPARK_YARN_APP_NAME", + "--files" -> "SPARK_YARN_DIST_FILES", + "--archives" -> "SPARK_YARN_DIST_ARCHIVES") + .foreach { case (optName, optParam) => addArg(optName, optParam, argsArrayBuf) }
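The YarnClientSchedulerBackend rewrite above drops the hand-rolled defaults and instead appends each optional flag only when its environment variable is set, letting ClientArguments supply defaults otherwise. A self-contained sketch of that approach follows; the env lookup is injected as a function so it can run without a YARN cluster, and the jar name is a placeholder:

import scala.collection.mutable.ArrayBuffer

object OptionalArgsSketch {
  def buildArgs(getenv: String => Option[String]): Array[String] = {
    val buf = new ArrayBuffer[String]()
    // Mandatory arguments are always present.
    buf += ("--class", "notused", "--jar", "app.jar")
    // Optional flags map to env vars; unset vars contribute nothing, so the
    // downstream parser's own defaults stay in force.
    Map("--worker-memory" -> "SPARK_WORKER_MEMORY",
        "--worker-cores" -> "SPARK_WORKER_CORES")
      .foreach { case (optName, envVar) =>
        getenv(envVar).foreach(value => buf += (optName, value))
      }
    buf.toArray
  }

  def main(args: Array[String]): Unit = {
    val env = Map("SPARK_WORKER_MEMORY" -> "1g")  // SPARK_WORKER_CORES unset
    println(buildArgs(env.get).mkString(" "))
    // prints: --class notused --jar app.jar --worker-memory 1g
  }
}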
+ + logDebug("ClientArguments called with: " + argsArrayBuf) + val args = new ClientArguments(argsArrayBuf.toArray, conf) client = new Client(args, conf) appId = client.runApp() waitForApp() diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 69ae14ce83..4b777d5fa7 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -116,14 +116,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, // local dirs, so let's check both. We assume one of the 2 is set. // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(System.getenv("LOCAL_DIRS")) - .getOrElse("")) - - if (localDirs.isEmpty()) { - throw new Exception("Yarn Local dirs can't be empty") + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l } - localDirs - } + } private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index be323d7783..952e963389 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -99,6 +99,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf) appContext.setApplicationName(args.appName) appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(amContainer) + appContext.setApplicationType("SPARK") // Memory for the ApplicationMaster. val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 49248a8516..78353224fa 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -78,6 +78,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar def run() { + // Setup the directories so things go to yarn approved directories rather + // than user specified and /tmp. + System.setProperty("spark.local.dir", getLocalDirs()) + amClient = AMRMClient.createAMRMClient() amClient.init(yarnConf) amClient.start() @@ -94,10 +98,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // must be <= timeoutInterval/ 2. - // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. - // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. - val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L)) + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) + + reporterThread = launchReporterThread(interval)
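The heartbeat change above replaces the old "at least one minute" floor with a configurable interval (spark.yarn.scheduler.heartbeat.interval-ms, default 5000), still clamped to half the RM expiry timeout so the AM cannot miss enough heartbeats to be declared dead. The arithmetic, isolated:

object HeartbeatIntervalSketch {
  // interval = min(timeout / 2, configured heartbeat)
  def reporterInterval(timeoutIntervalMs: Long, configuredMs: Long): Long =
    math.min(timeoutIntervalMs / 2, configuredMs)

  def main(args: Array[String]): Unit = {
    assert(reporterInterval(120000L, 5000L) == 5000L) // default: 5s heartbeat
    assert(reporterInterval(8000L, 5000L) == 4000L)   // clamped to timeout/2
  }
}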
// Wait for the reporter thread to Finish. @@ -110,6 +116,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar System.exit(0) } + /** Get the Yarn approved local directories. */ + private def getLocalDirs(): String = { + // Hadoop 0.23 and 2.x have different Environment variable names for the + // local dirs, so let's check both. We assume one of the 2 is set. + // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X + val localDirs = Option(System.getenv("YARN_LOCAL_DIRS")) + .orElse(Option(System.getenv("LOCAL_DIRS"))) + + localDirs match { + case None => throw new Exception("Yarn Local dirs can't be empty") + case Some(l) => l + } + } + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
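For context on getApplicationAttemptId above: a YARN container id string embeds the application attempt it belongs to, and the real code extracts it via Hadoop's ConverterUtils.toContainerId. The hypothetical mini-parser below only illustrates the idea; the id format shown ("container_<clusterTimestamp>_<appId>_<attempt>_<container>") is for illustration, and production code should always use the Hadoop utilities rather than string splitting.

object AttemptIdSketch {
  // Illustrative only: parse the attempt number out of a container id string.
  def attemptFromContainerId(containerId: String): Int =
    containerId.split("_") match {
      case Array("container", _, _, attempt, _) => attempt.toInt
      case _ => throw new IllegalArgumentException(s"bad id: $containerId")
    }

  def main(args: Array[String]): Unit =
    println(attemptFromContainerId("container_1389123456789_0001_01_000001")) // 1
}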