aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--conf/log4j.properties.template4
-rw-r--r--core/src/main/resources/org/apache/spark/log4j-defaults.properties4
-rw-r--r--core/src/main/scala/org/apache/spark/Aggregator.scala61
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala88
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Vector.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala (renamed from core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala)116
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala350
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala101
-rw-r--r--core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala120
-rw-r--r--core/src/test/scala/org/apache/spark/util/VectorSuite.scala44
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala)46
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala230
-rw-r--r--docs/configuration.md23
-rw-r--r--docs/running-on-yarn.md15
-rwxr-xr-xec2/spark_ec2.py12
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java7
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala5
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala118
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala188
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStream.scala15
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala106
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala38
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala75
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala96
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala42
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala31
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala55
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java29
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala10
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala13
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala28
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala50
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala13
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala1
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala28
45 files changed, 1894 insertions, 356 deletions
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
index 17d1978dde..f7f8535594 100644
--- a/conf/log4j.properties.template
+++ b/conf/log4j.properties.template
@@ -5,5 +5,7 @@ log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
+# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 17d1978dde..f7f8535594 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -5,5 +5,7 @@ log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
+# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 1a2ec55876..8b30cd4bfe 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.apache.spark.util.AppendOnlyMap
+import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
/**
* A set of functions used to aggregate data.
@@ -31,30 +31,51 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
+ private val sparkConf = SparkEnv.get.conf
+ private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
+
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kv: Product2[K, V] = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
- }
- while (iter.hasNext) {
- kv = iter.next()
- combiners.changeValue(kv._1, update)
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners =
+ new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ combiners.insert(k, v)
+ }
+ combiners.iterator
}
- combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kc: (K, C) = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kc: Product2[K, C] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, c) = iter.next()
+ combiners.insert(k, c)
+ }
+ combiners.iterator
}
- while (iter.hasNext) {
- kc = iter.next()
- combiners.changeValue(kc._1, update)
- }
- combiners.iterator
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 66c226e491..139048d5c7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -677,10 +677,10 @@ class SparkContext(
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (SparkHadoopUtil.get.isYarnMode()) {
- // In order for this to work on yarn the user must specify the --addjars option to
- // the client to upload the file into the distributed cache to make it show up in the
- // current working directory.
+ if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") {
+ // In order for this to work in yarn standalone mode the user must specify the
+ // --addjars option to the client to upload the file into the distributed cache
+ // of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
env.httpFileServer.addJar(new File(fileName))
@@ -1086,7 +1086,7 @@ object SparkContext {
* parameters that are passed as the default value of null, instead of throwing an exception
* like SparkConf would.
*/
- private def updatedConf(
+ private[spark] def updatedConf(
conf: SparkConf,
master: String,
appName: String,
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e093e2f162..08b592df71 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -54,7 +54,11 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) {
+ val conf: SparkConf) extends Logging {
+
+ // A mapping of thread ID to amount of memory used for shuffle in bytes
+ // All accesses should be manually synchronized
+ val shuffleMemoryMap = mutable.HashMap[Long, Long]()
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e51d274d33..a7b2328a02 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -279,6 +279,11 @@ private[spark] class Executor(
//System.exit(1)
}
} finally {
+ // TODO: Unregister shuffle memory only for ShuffleMapTask
+ val shuffleMemoryMap = env.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap.remove(Thread.currentThread().getId)
+ }
runningTasks.remove(taskId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 4ba4696fef..a73714abca 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
-import org.apache.spark.util.AppendOnlyMap
-
+import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
-class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
-
/**
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
@@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+ // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // CoGroupValue is the intermediate state of each value before being merged in compute.
+ private type CoGroup = ArrayBuffer[Any]
+ private type CoGroupValue = (Any, Int) // Int is dependency number
+ private type CoGroupCombiner = Seq[CoGroup]
+
+ private val sparkConf = SparkEnv.get.conf
private var serializerClass: String = null
def setSerializer(cls: String): CoGroupedRDD[K] = {
@@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override val partitioner = Some(part)
- override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
+ val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
- // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ // A list of (rdd iterator, dependency number) pairs
+ val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
- rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
- getSeq(kv._1)(depNum) += kv._2
- }
+ val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+ rddIterators += ((it, depNum))
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
- kv => getSeq(kv._1)(depNum) += kv._2
+ val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ rddIterators += ((it, depNum))
+ }
+ }
+
+ if (!externalSorting) {
+ val map = new AppendOnlyMap[K, CoGroupCombiner]
+ val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup)
+ }
+ val getCombiner: K => CoGroupCombiner = key => {
+ map.changeValue(key, update)
+ }
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ getCombiner(kv._1)(depNum) += kv._2
}
}
+ new InterruptibleIterator(context, map.iterator)
+ } else {
+ val map = createExternalMap(numRdds)
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ map.insert(kv._1, new CoGroupValue(kv._2, depNum))
+ }
+ }
+ new InterruptibleIterator(context, map.iterator)
+ }
+ }
+
+ private def createExternalMap(numRdds: Int)
+ : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {
+
+ val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
+ val newCombiner = Array.fill(numRdds)(new CoGroup)
+ value match { case (v, depNum) => newCombiner(depNum) += v }
+ newCombiner
}
- new InterruptibleIterator(context, map.iterator)
+ val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
+ (combiner, value) => {
+ value match { case (v, depNum) => combiner(depNum) += v }
+ combiner
+ }
+ val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
+ (combiner1, combiner2) => {
+ combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 }
+ }
+ new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
+ createCombiner, mergeValue, mergeCombiners)
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c118ddfc01..1248409e35 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -99,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- // A sanity check to make sure mergeCombiners is not defined.
- assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
@@ -267,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
+ createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -339,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
- : RDD[(K, C)] = {
+ : RDD[(K, C)] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 043e01dbfb..38b536023b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -106,7 +106,7 @@ class DAGScheduler(
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50.milliseconds
+ val RESUBMIT_TIMEOUT = 200.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
@@ -196,7 +196,7 @@ class DAGScheduler(
*/
def receive = {
case event: DAGSchedulerEvent =>
- logDebug("Got event of type " + event.getClass.getName)
+ logTrace("Got event of type " + event.getClass.getName)
/**
* All events are forwarded to `processEvent()`, so that the event processing logic can
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7156d855d8..301d784b35 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -17,12 +17,14 @@
package org.apache.spark.storage
+import java.util.UUID
+
/**
* Identifies a particular Block of data, usually associated with a single file.
* A Block can be uniquely identified by its filename, but each type of Block has a different
* set of keys which produce its unique name.
*
- * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
+ * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method.
*/
private[spark] sealed abstract class BlockId {
/** A globally unique identifier for this Block. Can be used for ser/de. */
@@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
-private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+private[spark]
+case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
}
@@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B
def name = "input-" + streamId + "-" + uniqueId
}
+/** Id associated with temporary data managed as blocks. Not serializable. */
+private[spark] case class TempBlockId(id: UUID) extends BlockId {
+ def name = "temp_" + id
+}
+
// Intended only for testing purposes
private[spark] case class TestBlockId(id: String) extends BlockId {
def name = "test_" + id
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c56e2ca2df..ff9f241fc1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -159,7 +159,7 @@ private[spark] class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
- * thread if our heartbeat to the block amnager indicates that we were not registered.
+ * thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/
@@ -864,7 +864,7 @@ private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
def getMaxMemory(conf: SparkConf): Long = {
- val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.66)
+ val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6)
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 61e63c60d5..369a277232 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -181,4 +181,8 @@ class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting
+
+ def bytesWritten: Long = {
+ lastValidPosition - initialPosition
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index edc1133172..a8ef7fa8b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
-import java.util.{Date, Random}
+import java.util.{Date, Random, UUID}
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
@@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Produces a unique block id and File suitable for intermediate results. */
+ def createTempBlock(): (TempBlockId, File) = {
+ var blockId = new TempBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 181ae2fd45..8e07a0f29a 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -26,16 +26,23 @@ import org.apache.spark.Logging
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
- * time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in
- * replacement of scala.collection.mutable.HashMap.
+ * timestamp along with each key-value pair. If specified, the timestamp of each pair can be
+ * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
+ * threshold time can then be removed using the clearOldValues method. This is intended to
+ * be a drop-in replacement of scala.collection.mutable.HashMap.
+ * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
+ * updated when it is accessed
*/
-class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
+class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends Map[A, B]() with Logging {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
- if (value != null) Some(value._1) else None
+ if (value != null && updateTimeStampOnGet) {
+ internalMap.replace(key, value, (value._1, currentTime))
+ }
+ Option(value).map(_._1)
}
def iterator: Iterator[(A, B)] = {
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fe710c58ac..62fd6d8da5 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.util.Random
+
class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
@@ -124,6 +126,12 @@ object Vector {
def ones(length: Int) = Vector(length, _ => 1)
+ /**
+ * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
+ * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided.
+ */
+ def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble())
+
class Multiplier(num: Double) {
def * (vec: Vector) = vec * num
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 8bb4ee3bfa..d98c7aa3d7 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
+
+import java.util.{Arrays, Comparator}
/**
* A simple open hash table optimized for the append-only use case, where keys
@@ -28,14 +30,15 @@ package org.apache.spark.util
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
private[spark]
-class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K,
+ V)] with Serializable {
require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
require(initialCapacity >= 1, "Invalid initial capacity")
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
- private var growThreshold = LOAD_FACTOR * capacity
+ private var growThreshold = (LOAD_FACTOR * capacity).toInt
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var haveNullValue = false
private var nullValue: V = null.asInstanceOf[V]
+ // Triggered by destructiveSortedIterator; the underlying data array may no longer be used
+ private var destroyed = false
+ private val destructionMessage = "Map state is invalid from destructive sorting!"
+
private val LOAD_FACTOR = 0.7
/** Get the value for a given key */
def apply(key: K): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
return nullValue
@@ -72,6 +80,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Set the value for a key */
def update(key: K, value: V): Unit = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
* for key, if any, or null otherwise. Returns the newly updated value.
*/
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Iterator method from Iterable */
- override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
- var pos = -1
-
- /** Get the next value we should return from next(), or null if we're finished iterating */
- def nextValue(): (K, V) = {
- if (pos == -1) { // Treat position -1 as looking at the null value
- if (haveNullValue) {
- return (null.asInstanceOf[K], nullValue)
+ override def iterator: Iterator[(K, V)] = {
+ assert(!destroyed, destructionMessage)
+ new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
}
- pos += 1
- }
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
}
- pos += 1
+ null
}
- null
- }
- override def hasNext: Boolean = nextValue() != null
+ override def hasNext: Boolean = nextValue() != null
- override def next(): (K, V) = {
- val value = nextValue()
- if (value == null) {
- throw new NoSuchElementException("End of iterator")
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
}
- pos += 1
- value
}
}
@@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Double the table's size and re-hash everything */
- private def growTable() {
+ protected def growTable() {
val newCapacity = capacity * 2
if (newCapacity >= (1 << 30)) {
// We can't make the table this big because we want an array of 2x
@@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
data = newData
capacity = newCapacity
mask = newMask
- growThreshold = LOAD_FACTOR * newCapacity
+ growThreshold = (LOAD_FACTOR * newCapacity).toInt
}
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
if (highBit == n) n else highBit << 1
}
+
+ /**
+ * Return an iterator of the map in sorted order. This provides a way to sort the map without
+ * using additional memory, at the expense of destroying the validity of the map.
+ */
+ def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
+ destroyed = true
+ // Pack KV pairs into the front of the underlying array
+ var keyIndex, newIndex = 0
+ while (keyIndex < capacity) {
+ if (data(2 * keyIndex) != null) {
+ data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
+ newIndex += 1
+ }
+ keyIndex += 1
+ }
+ assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
+
+ // Sort by the given ordering
+ val rawOrdering = new Comparator[AnyRef] {
+ def compare(x: AnyRef, y: AnyRef): Int = {
+ cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
+ }
+ }
+ Arrays.sort(data, 0, newIndex, rawOrdering)
+
+ new Iterator[(K, V)] {
+ var i = 0
+ var nullValueReady = haveNullValue
+ def hasNext: Boolean = (i < newIndex || nullValueReady)
+ def next(): (K, V) = {
+ if (nullValueReady) {
+ nullValueReady = false
+ (null.asInstanceOf[K], nullValue)
+ } else {
+ val item = data(i).asInstanceOf[(K, V)]
+ i += 1
+ item
+ }
+ }
+ }
+ }
+
+ /**
+ * Return whether the next insert will cause the map to grow
+ */
+ def atGrowThreshold: Boolean = curSize == growThreshold
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
new file mode 100644
index 0000000000..e3bcd895aa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
+
+/**
+ * An append-only map that spills sorted content to disk when there is insufficient space for it
+ * to grow.
+ *
+ * This map takes two passes over the data:
+ *
+ * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary
+ * (2) Combiners are read from disk and merged together
+ *
+ * The setting of the spill threshold faces the following trade-off: If the spill threshold is
+ * too high, the in-memory map may occupy more memory than is available, resulting in OOM.
+ * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk
+ * writes. This may lead to a performance regression compared to the normal case of using the
+ * non-spilling AppendOnlyMap.
+ *
+ * Two parameters control the memory threshold:
+ *
+ * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing
+ * these maps as a fraction of the executor's total memory. Since each concurrently running
+ * task maintains one map, the actual threshold for each map is this quantity divided by the
+ * number of running tasks.
+ *
+ * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of
+ * this threshold, in case map size estimation is not sufficiently accurate.
+ */
+
+private[spark] class ExternalAppendOnlyMap[K, V, C](
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ serializer: Serializer = SparkEnv.get.serializerManager.default,
+ diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager)
+ extends Iterable[(K, C)] with Serializable with Logging {
+
+ import ExternalAppendOnlyMap._
+
+ private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ private val spilledMaps = new ArrayBuffer[DiskMapIterator]
+ private val sparkConf = SparkEnv.get.conf
+
+ // Collective memory threshold shared across all running tasks
+ private val maxMemoryThreshold = {
+ val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3)
+ val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+
+ // Number of pairs in the in-memory map
+ private var numPairsInMemory = 0
+
+ // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
+ private val trackMemoryThreshold = 1000
+
+ // How many times we have spilled so far
+ private var spillCount = 0
+
+ private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+ private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
+ private val comparator = new KCComparator[K, C]
+ private val ser = serializer.newInstance()
+
+ /**
+ * Insert the given key and value into the map.
+ *
+ * If the underlying map is about to grow, check if the global pool of shuffle memory has
+ * enough room for this to happen. If so, allocate the memory required to grow the map;
+ * otherwise, spill the in-memory map to disk.
+ *
+ * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
+ */
+ def insert(key: K, value: V) {
+ val update: (Boolean, C) => C = (hadVal, oldVal) => {
+ if (hadVal) mergeValue(oldVal, value) else createCombiner(value)
+ }
+ if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
+ val mapSize = currentMap.estimateSize()
+ var shouldSpill = false
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+ // Atomically check whether there is sufficient memory in the global pool for
+ // this map to grow and, if possible, allocate the required amount
+ shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
+ val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
+ val availableMemory = maxMemoryThreshold -
+ (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
+
+ // Assume map growth factor is 2x
+ shouldSpill = availableMemory < mapSize * 2
+ if (!shouldSpill) {
+ shuffleMemoryMap(threadId) = mapSize * 2
+ }
+ }
+ // Do not synchronize spills
+ if (shouldSpill) {
+ spill(mapSize)
+ }
+ }
+ currentMap.changeValue(key, update)
+ numPairsInMemory += 1
+ }
+
+ /**
+ * Sort the existing contents of the in-memory map and spill them to a temporary file on disk
+ */
+ private def spill(mapSize: Long) {
+ spillCount += 1
+ logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
+ .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ val (blockId, file) = diskBlockManager.createTempBlock()
+ val writer =
+ new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
+ try {
+ val it = currentMap.destructiveSortedIterator(comparator)
+ while (it.hasNext) {
+ val kv = it.next()
+ writer.write(kv)
+ }
+ writer.commit()
+ } finally {
+ // Partial failures cannot be tolerated; do not revert partial writes
+ writer.close()
+ }
+ currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ spilledMaps.append(new DiskMapIterator(file))
+
+ // Reset the amount of shuffle memory used by this map in the global pool
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap(Thread.currentThread().getId) = 0
+ }
+ numPairsInMemory = 0
+ }
+
+ /**
+ * Return an iterator that merges the in-memory map with the spilled maps.
+ * If no spill has occurred, simply return the in-memory map's iterator.
+ */
+ override def iterator: Iterator[(K, C)] = {
+ if (spilledMaps.isEmpty) {
+ currentMap.iterator
+ } else {
+ new ExternalIterator()
+ }
+ }
+
+ /**
+ * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
+ */
+ private class ExternalIterator extends Iterator[(K, C)] {
+
+ // A fixed-size queue that maintains a buffer for each stream we are currently merging
+ val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
+
+ // Input streams are derived both from the in-memory map and spilled maps on disk
+ // The in-memory map is sorted in place, while the spilled maps are already in sorted order
+ val sortedMap = currentMap.destructiveSortedIterator(comparator)
+ val inputStreams = Seq(sortedMap) ++ spilledMaps
+
+ inputStreams.foreach { it =>
+ val kcPairs = getMorePairs(it)
+ mergeHeap.enqueue(StreamBuffer(it, kcPairs))
+ }
+
+ /**
+ * Fetch from the given iterator until a key of different hash is retrieved. In the
+ * event of key hash collisions, this ensures no pairs are hidden from being merged.
+ * Assume the given iterator is in sorted order.
+ */
+ def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
+ val kcPairs = new ArrayBuffer[(K, C)]
+ if (it.hasNext) {
+ var kc = it.next()
+ kcPairs += kc
+ val minHash = kc._1.hashCode()
+ while (it.hasNext && kc._1.hashCode() == minHash) {
+ kc = it.next()
+ kcPairs += kc
+ }
+ }
+ kcPairs
+ }
+
+ /**
+ * If the given buffer contains a value for the given key, merge that value into
+ * baseCombiner and remove the corresponding (K, C) pair from the buffer
+ */
+ def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
+ var i = 0
+ while (i < buffer.pairs.size) {
+ val (k, c) = buffer.pairs(i)
+ if (k == key) {
+ buffer.pairs.remove(i)
+ return mergeCombiners(baseCombiner, c)
+ }
+ i += 1
+ }
+ baseCombiner
+ }
+
+ /**
+ * Return true if there exists an input stream that still has unvisited pairs
+ */
+ override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
+
+ /**
+ * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+ */
+ override def next(): (K, C) = {
+ // Select a key from the StreamBuffer that holds the lowest key hash
+ val minBuffer = mergeHeap.dequeue()
+ val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
+ if (minPairs.length == 0) {
+ // Should only happen when no other stream buffers have any pairs left
+ throw new NoSuchElementException
+ }
+ var (minKey, minCombiner) = minPairs.remove(0)
+ assert(minKey.hashCode() == minHash)
+
+ // For all other streams that may have this key (i.e. have the same minimum key hash),
+ // merge in the corresponding value (if any) from that stream
+ val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
+ while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) {
+ val newBuffer = mergeHeap.dequeue()
+ minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
+ mergedBuffers += newBuffer
+ }
+
+ // Repopulate each visited stream buffer and add it back to the merge heap
+ mergedBuffers.foreach { buffer =>
+ if (buffer.pairs.length == 0) {
+ buffer.pairs ++= getMorePairs(buffer.iterator)
+ }
+ mergeHeap.enqueue(buffer)
+ }
+
+ (minKey, minCombiner)
+ }
+
+ /**
+ * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash.
+ * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to
+ * hash collisions, it is possible for multiple keys to be "tied" for being the lowest.
+ *
+ * StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
+ */
+ case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
+ extends Comparable[StreamBuffer] {
+
+ def minKeyHash: Int = {
+ if (pairs.length > 0){
+ // pairs are already sorted by key hash
+ pairs(0)._1.hashCode()
+ } else {
+ Int.MaxValue
+ }
+ }
+
+ override def compareTo(other: StreamBuffer): Int = {
+ // minus sign because mutable.PriorityQueue dequeues the max, not the min
+ -minKeyHash.compareTo(other.minKeyHash)
+ }
+ }
+ }
+
+ /**
+ * An iterator that returns (K, C) pairs in sorted order from an on-disk map
+ */
+ private class DiskMapIterator(file: File) extends Iterator[(K, C)] {
+ val fileStream = new FileInputStream(file)
+ val bufferedStream = new FastBufferedInputStream(fileStream)
+ val deserializeStream = ser.deserializeStream(bufferedStream)
+ var nextItem: (K, C) = null
+ var eof = false
+
+ def readNextItem(): (K, C) = {
+ if (!eof) {
+ try {
+ return deserializeStream.readObject().asInstanceOf[(K, C)]
+ } catch {
+ case e: EOFException =>
+ eof = true
+ cleanup()
+ }
+ }
+ null
+ }
+
+ override def hasNext: Boolean = {
+ if (nextItem == null) {
+ nextItem = readNextItem()
+ }
+ nextItem != null
+ }
+
+ override def next(): (K, C) = {
+ val item = if (nextItem == null) readNextItem() else nextItem
+ if (item == null) {
+ throw new NoSuchElementException
+ }
+ nextItem = null
+ item
+ }
+
+ // TODO: Ensure this gets called even if the iterator isn't drained.
+ def cleanup() {
+ deserializeStream.close()
+ file.delete()
+ }
+ }
+}
+
+private[spark] object ExternalAppendOnlyMap {
+ private class KCComparator[K, C] extends Comparator[(K, C)] {
+ def compare(kc1: (K, C), kc2: (K, C)): Int = {
+ kc1._1.hashCode().compareTo(kc2._1.hashCode())
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
new file mode 100644
index 0000000000..204330dad4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.util.SizeEstimator
+import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample
+
+/**
+ * Append-only map that keeps track of its estimated size in bytes.
+ * We sample with a slow exponential back-off using the SizeEstimator to amortize the time,
+ * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds).
+ */
+private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] {
+
+ /**
+ * Controls the base of the exponential which governs the rate of sampling.
+ * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements.
+ */
+ private val SAMPLE_GROWTH_RATE = 1.1
+
+ /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */
+ private val samples = new ArrayBuffer[Sample]()
+
+ /** Total number of insertions and updates into the map since the last resetSamples(). */
+ private var numUpdates: Long = _
+
+ /** The value of 'numUpdates' at which we will take our next sample. */
+ private var nextSampleNum: Long = _
+
+ /** The average number of bytes per update between our last two samples. */
+ private var bytesPerUpdate: Double = _
+
+ resetSamples()
+
+ /** Called after the map grows in size, as this can be a dramatic change for small objects. */
+ def resetSamples() {
+ numUpdates = 1
+ nextSampleNum = 1
+ samples.clear()
+ takeSample()
+ }
+
+ override def update(key: K, value: V): Unit = {
+ super.update(key, value)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ }
+
+ override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val newValue = super.changeValue(key, updateFunc)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ newValue
+ }
+
+ /** Takes a new sample of the current map's size. */
+ def takeSample() {
+ samples += Sample(SizeEstimator.estimate(this), numUpdates)
+ // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change.
+ bytesPerUpdate = math.max(0, samples.toSeq.reverse match {
+ case latest :: previous :: tail =>
+ (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
+ case _ =>
+ 0
+ })
+ nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
+ }
+
+ override protected def growTable() {
+ super.growTable()
+ resetSamples()
+ }
+
+ /** Estimates the current size of the map in bytes. O(1) time. */
+ def estimateSize(): Long = {
+ assert(samples.nonEmpty)
+ val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
+ (samples.last.size + extrapolatedDelta).toLong
+ }
+}
+
+private object SizeTrackingAppendOnlyMap {
+ case class Sample(size: Long, numUpdates: Long)
+}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..93f0c6a8e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass
+import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap}
+
+class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll {
+ val NORMAL_ERROR = 0.20
+ val HIGH_ERROR = 0.30
+
+ test("fixed size insertions") {
+ testWith[Int, Long](10000, i => (i, i.toLong))
+ testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
+ testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass()))
+ }
+
+ test("variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[Int, String](10000, i => (i, randString(0, 10)))
+ testWith[Int, String](10000, i => (i, randString(0, 100)))
+ testWith[Int, String](10000, i => (i, randString(90, 100)))
+ }
+
+ test("updates") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[String, Int](10000, i => (randString(0, 10000), i))
+ }
+
+ def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ val map = new SizeTrackingAppendOnlyMap[K, V]()
+ for (i <- 0 until numElements) {
+ val (k, v) = makeElement(i)
+ map(k) = v
+ expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ val betterEstimatedSize = SizeEstimator.estimate(obj)
+ assert(betterEstimatedSize * (1 - error) < estimatedSize,
+ s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
+ assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
+ s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
+ }
+}
+
+object SizeTrackingAppendOnlyMapSuite {
+ // Speed test, for reproducibility of results.
+ // These could be highly non-deterministic in general, however.
+ // Results:
+ // AppendOnlyMap: 31 ms
+ // SizeTracker: 54 ms
+ // SizeEstimator: 1500 ms
+ def main(args: Array[String]) {
+ val numElements = 100000
+
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ }
+ }
+
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ map.estimateSize()
+ }
+ }
+
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ SizeEstimator.estimate(map)
+ }
+ }
+
+ println("Base: " + baseTimes)
+ println("SizeTracker (sampled): " + sampledTimes)
+ println("SizeEstimator (unsampled): " + unsampledTimes)
+ }
+
+ def time(f: => Unit): Long = {
+ val start = System.currentTimeMillis()
+ f
+ System.currentTimeMillis() - start
+ }
+
+ private class LargeDummyClass {
+ val arr = new Array[Int](100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
new file mode 100644
index 0000000000..7006571ef0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+/**
+ * Tests org.apache.spark.util.Vector functionality
+ */
+class VectorSuite extends FunSuite {
+
+ def verifyVector(vector: Vector, expectedLength: Int) = {
+ assert(vector.length == expectedLength)
+ assert(vector.elements.min > 0.0)
+ assert(vector.elements.max < 1.0)
+ }
+
+ test("random with default random number generator") {
+ val vector100 = Vector.random(100)
+ verifyVector(vector100, 100)
+ }
+
+ test("random with given random number generator") {
+ val vector100 = Vector.random(100, new Random(100))
+ verifyVector(vector100, 100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
index 7177919a58..f44442f1a5 100644
--- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import java.util.Comparator
class AppendOnlyMapSuite extends FunSuite {
test("initialization") {
@@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite {
assert(map("" + i) === "" + i)
}
}
+
+ test("destructive sort") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ map.update(null, "happy new year!")
+
+ try {
+ map.apply("1")
+ map.update("1", "2013")
+ map.changeValue("1", (hadValue, oldValue) => "2014")
+ map.iterator
+ } catch {
+ case e: IllegalStateException => fail()
+ }
+
+ val it = map.destructiveSortedIterator(new Comparator[(String, String)] {
+ def compare(kv1: (String, String), kv2: (String, String)): Int = {
+ val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue
+ val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue
+ x.compareTo(y)
+ }
+ })
+
+ // Should be sorted by key
+ assert(it.hasNext)
+ var previous = it.next()
+ assert(previous == (null, "happy new year!"))
+ previous = it.next()
+ assert(previous == ("1", "2014"))
+ while (it.hasNext) {
+ val kv = it.next()
+ assert(kv._1.toInt > previous._1.toInt)
+ previous = kv
+ }
+
+ // All subsequent calls to apply, update, changeValue and iterator should throw exception
+ intercept[AssertionError] { map.apply("1") }
+ intercept[AssertionError] { map.update("1", "2013") }
+ intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") }
+ intercept[AssertionError] { map.iterator }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..ef957bb0e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -0,0 +1,230 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ override def beforeEach() {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.externalSorting", "true")
+ sc = new SparkContext("local", "test", conf)
+ }
+
+ val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
+ val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
+ buffer += i
+ }
+ val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
+ (buf1, buf2) => {
+ buf1 ++= buf2
+ }
+
+ test("simple insert") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ // Single insert
+ map.insert(1, 10)
+ var it = map.iterator
+ assert(it.hasNext)
+ val kv = it.next()
+ assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10))
+ assert(!it.hasNext)
+
+ // Multiple insert
+ map.insert(2, 20)
+ map.insert(3, 30)
+ it = map.iterator
+ assert(it.hasNext)
+ assert(it.toSet == Set[(Int, ArrayBuffer[Int])](
+ (1, ArrayBuffer[Int](10)),
+ (2, ArrayBuffer[Int](20)),
+ (3, ArrayBuffer[Int](30))))
+ }
+
+ test("insert with collision") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ map.insert(1, 10)
+ map.insert(2, 20)
+ map.insert(3, 30)
+ map.insert(1, 100)
+ map.insert(2, 200)
+ map.insert(1, 1000)
+ val it = map.iterator
+ assert(it.hasNext)
+ val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](10, 100, 1000)),
+ (2, Set[Int](20, 200)),
+ (3, Set[Int](30))))
+ }
+
+ test("ordering") {
+ val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map1.insert(1, 10)
+ map1.insert(2, 20)
+ map1.insert(3, 30)
+
+ val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map2.insert(2, 20)
+ map2.insert(3, 30)
+ map2.insert(1, 10)
+
+ val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map3.insert(3, 30)
+ map3.insert(1, 10)
+ map3.insert(2, 20)
+
+ val it1 = map1.iterator
+ val it2 = map2.iterator
+ val it3 = map3.iterator
+
+ var kv1 = it1.next()
+ var kv2 = it2.next()
+ var kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ }
+
+ test("null keys and values") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map.insert(1, 5)
+ map.insert(2, 6)
+ map.insert(3, 7)
+ assert(map.size === 3)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7))
+ ))
+
+ // Null keys
+ val nullInt = null.asInstanceOf[Int]
+ map.insert(nullInt, 8)
+ assert(map.size === 4)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7)),
+ (nullInt, Seq[Int](8))
+ ))
+
+ // Null values
+ map.insert(4, nullInt)
+ map.insert(nullInt, nullInt)
+ assert(map.size === 5)
+ val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](5)),
+ (2, Set[Int](6)),
+ (3, Set[Int](7)),
+ (4, Set[Int](nullInt)),
+ (nullInt, Set[Int](nullInt, 8))
+ ))
+ }
+
+ test("simple aggregator") {
+ // reduceByKey
+ val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
+ val result1 = rdd.reduceByKey(_+_).collect()
+ assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5)))
+
+ // groupByKey
+ val result2 = rdd.groupByKey().collect()
+ assert(result2.toSet == Set[(Int, Seq[Int])]
+ ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1))))
+ }
+
+ test("simple cogroup") {
+ val rdd1 = sc.parallelize(1 to 4).map(i => (i, i))
+ val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i))
+ val result = rdd1.cogroup(rdd2).collect()
+
+ result.foreach { case (i, (seq1, seq2)) =>
+ i match {
+ case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4))
+ case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3))
+ case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]())
+ case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]())
+ case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]())
+ }
+ }
+ }
+
+ test("spilling") {
+ // TODO: Figure out correct memory parameters to actually induce spilling
+ // System.setProperty("spark.shuffle.buffer.mb", "1")
+ // System.setProperty("spark.shuffle.buffer.fraction", "0.05")
+
+ // reduceByKey - should spill exactly 6 times
+ val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max(_, _)).collect()
+ assert(resultA.length == 5000)
+ resultA.foreach { case(k, v) =>
+ k match {
+ case 0 => assert(v == 1)
+ case 2500 => assert(v == 5001)
+ case 4999 => assert(v == 9999)
+ case _ =>
+ }
+ }
+
+ // groupByKey - should spill exactly 11 times
+ val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey().collect()
+ assert(resultB.length == 2500)
+ resultB.foreach { case(i, seq) =>
+ i match {
+ case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
+ case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003))
+ case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999))
+ case _ =>
+ }
+ }
+
+ // cogroup - should spill exactly 7 times
+ val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i))
+ val resultC = rddC1.cogroup(rddC2).collect()
+ assert(resultC.length == 1000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900))
+ case 500 =>
+ assert(seq1.toSet == Set[Int](500))
+ assert(seq2.toSet == Set[Int]())
+ case 999 =>
+ assert(seq1.toSet == Set[Int](999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+ }
+
+ // TODO: Test memory allocation for multiple concurrently running tasks
+}
diff --git a/docs/configuration.md b/docs/configuration.md
index b1a0e19167..ad75e06fc7 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -104,14 +104,25 @@ Apart from these, the following properties are also available, and may be useful
</tr>
<tr>
<td>spark.storage.memoryFraction</td>
- <td>0.66</td>
+ <td>0.6</td>
<td>
Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old"
- generation of objects in the JVM, which by default is given 2/3 of the heap, but you can increase
+ generation of objects in the JVM, which by default is given 0.6 of the heap, but you can increase
it if you configure your own old generation size.
</td>
</tr>
<tr>
+ <td>spark.shuffle.memoryFraction</td>
+ <td>0.3</td>
+ <td>
+ Fraction of Java heap to use for aggregation and cogroups during shuffles, if
+ <code>spark.shuffle.externalSorting</code> is enabled. At any given time, the collective size of
+ all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
+ begin to spill to disk. If spills are often, consider increasing this value at the expense of
+ <code>spark.storage.memoryFraction</code>.
+ </td>
+</tr>
+<tr>
<td>spark.mesos.coarse</td>
<td>false</td>
<td>
@@ -377,6 +388,14 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.shuffle.externalSorting</td>
+ <td>true</td>
+ <td>
+ If set to "true", limits the amount of memory used during reduces by spilling data out to disk. This spilling
+ threshold is specified by <code>spark.shuffle.memoryFraction</code>.
+ </td>
+</tr>
+<tr>
<td>spark.speculation</td>
<td>false</td>
<td>
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index b206270107..3bd62646ba 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -101,7 +101,19 @@ With this mode, your application is actually run on the remote machine where the
With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
-In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh
+Configuration in yarn-client mode:
+
+In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options.
+
+* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required)
+* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2)
+* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1).
+* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G)
+* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)
+* `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark)
+* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default')
+* `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job.
+* `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job.
For example:
@@ -114,7 +126,6 @@ For example:
SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
MASTER=yarn-client ./bin/spark-shell
-You can also send extra files to yarn cluster for worker to use by exporting SPARK_YARN_DIST_FILES=file1,file2... etc.
# Building Spark for Hadoop/YARN 2.2.x
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index d82a1e1490..e7cb5ab3ff 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -185,7 +185,11 @@ def get_spark_ami(opts):
"hi1.4xlarge": "hvm",
"m3.xlarge": "hvm",
"m3.2xlarge": "hvm",
- "cr1.8xlarge": "hvm"
+ "cr1.8xlarge": "hvm",
+ "i2.xlarge": "hvm",
+ "i2.2xlarge": "hvm",
+ "i2.4xlarge": "hvm",
+ "i2.8xlarge": "hvm"
}
if opts.instance_type in instance_types:
instance_type = instance_types[opts.instance_type]
@@ -478,7 +482,11 @@ def get_num_disks(instance_type):
"cr1.8xlarge": 2,
"hi1.4xlarge": 2,
"m3.xlarge": 0,
- "m3.2xlarge": 0
+ "m3.2xlarge": 0,
+ "i2.xlarge": 1,
+ "i2.2xlarge": 2,
+ "i2.4xlarge": 4,
+ "i2.8xlarge": 8
}
if instance_type in disks_by_instance:
return disks_by_instance[instance_type]
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
index 2e616b1ab2..349d826ab5 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
@@ -48,7 +48,7 @@ public final class JavaNetworkWordCount {
public static void main(String[] args) {
if (args.length < 3) {
- System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
+ System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1");
System.exit(1);
}
@@ -56,12 +56,12 @@ public final class JavaNetworkWordCount {
StreamingExamples.setStreamingLogLevels();
// Create the context with a 1 second batch size
- JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount",
new Duration(1000), System.getenv("SPARK_HOME"),
JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class));
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2]));
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
@@ -84,6 +84,5 @@ public final class JavaNetworkWordCount {
wordCounts.print();
ssc.start();
-
}
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
index c12139b3ec..25f7013307 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
@@ -21,7 +21,8 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
/**
- * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
* Usage: NetworkWordCount <master> <hostname> <port>
* <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
@@ -46,7 +47,7 @@ object NetworkWordCount {
System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
val lines = ssc.socketTextStream(args(1), args(2).toInt)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
new file mode 100644
index 0000000000..d51e6e9418
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples
+
+import org.apache.spark.streaming.{Time, Seconds, StreamingContext}
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.util.IntParam
+import java.io.File
+import org.apache.spark.rdd.RDD
+import com.google.common.io.Files
+import java.nio.charset.Charset
+
+/**
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
+ * Usage: NetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ * <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
+ * <output-file> file to which the word counts will be appended
+ *
+ * In local mode, <master> should be 'local[n]' with n > 1
+ * <checkpoint-directory> and <output-file> must be absolute paths
+ *
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ *
+ * `$ nc -lk 9999`
+ *
+ * and run the example as
+ *
+ * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount \
+ * local[2] localhost 9999 ~/checkpoint/ ~/out`
+ *
+ * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ * the checkpoint data.
+ *
+ * To run this example in a local standalone cluster with automatic driver recovery,
+ *
+ * `$ ./spark-class org.apache.spark.deploy.Client -s launch <cluster-url> <path-to-examples-jar> \
+ * org.apache.spark.streaming.examples.RecoverableNetworkWordCount <cluster-url> \
+ * localhost 9999 ~/checkpoint ~/out`
+ *
+ * <path-to-examples-jar> would typically be <spark-dir>/examples/target/scala-XX/spark-examples....jar
+ *
+ * Refer to the online documentation for more details.
+ */
+
+object RecoverableNetworkWordCount {
+
+ def createContext(master: String, ip: String, port: Int, outputPath: String) = {
+
+ // If you do not see this printed, that means the StreamingContext has been loaded
+ // from the new checkpoint
+ println("Creating new context")
+ val outputFile = new File(outputPath)
+ if (outputFile.exists()) outputFile.delete()
+
+ // Create the context with a 1 second batch size
+ val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1),
+ System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
+
+ // Create a NetworkInputDStream on target ip:port and count the
+ // words in input stream of \n delimited text (eg. generated by 'nc')
+ val lines = ssc.socketTextStream(ip, port)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.foreach((rdd: RDD[(String, Int)], time: Time) => {
+ val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]")
+ println(counts)
+ println("Appending to " + outputFile.getAbsolutePath)
+ Files.append(counts + "\n", outputFile, Charset.defaultCharset())
+ })
+ ssc
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("You arguments were " + args.mkString("[", ", ", "]"))
+ System.err.println(
+ """
+ |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
+ | <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ | <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ | <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
+ | <output-file> file to which the word counts will be appended
+ |
+ |In local mode, <master> should be 'local[n]' with n > 1
+ |Both <checkpoint-directory> and <output-file> must be absolute paths
+ """.stripMargin
+ )
+ System.exit(1)
+ }
+ val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args
+ val ssc = StreamingContext.getOrCreate(checkpointDirectory,
+ () => {
+ createContext(master, ip, port, outputPath)
+ })
+ ssc.start()
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index ca0115f90e..1249ef4c3d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -24,10 +24,10 @@ import java.util.concurrent.RejectedExecutionException
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{SparkException, SparkConf, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.MetadataCleaner
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.streaming.scheduler.JobGenerator
private[streaming]
@@ -44,6 +44,10 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
val sparkConf = ssc.conf
+ // These should be unset when a checkpoint is deserialized,
+ // otherwise the SparkContext won't initialize correctly.
+ sparkConf.remove("spark.hostPort").remove("spark.driver.host").remove("spark.driver.port")
+
def validate() {
assert(master != null, "Checkpoint.master is null")
assert(framework != null, "Checkpoint.framework is null")
@@ -53,59 +57,119 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
}
}
+private[streaming]
+object Checkpoint extends Logging {
+ val PREFIX = "checkpoint-"
+ val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r
+
+ /** Get the checkpoint file for the given checkpoint time */
+ def checkpointFile(checkpointDir: String, checkpointTime: Time) = {
+ new Path(checkpointDir, PREFIX + checkpointTime.milliseconds)
+ }
+
+ /** Get the checkpoint backup file for the given checkpoint time */
+ def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = {
+ new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk")
+ }
+
+ /** Get checkpoint files present in the give directory, ordered by oldest-first */
+ def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = {
+ def sortFunc(path1: Path, path2: Path): Boolean = {
+ val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+ val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+ (time1 < time2) || (time1 == time2 && bk1)
+ }
+
+ val path = new Path(checkpointDir)
+ if (fs.exists(path)) {
+ val statuses = fs.listStatus(path)
+ if (statuses != null) {
+ val paths = statuses.map(_.getPath)
+ val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty)
+ filtered.sortWith(sortFunc)
+ } else {
+ logWarning("Listing " + path + " returned null")
+ Seq.empty
+ }
+ } else {
+ logInfo("Checkpoint directory " + path + " does not exist")
+ Seq.empty
+ }
+ }
+}
+
/**
* Convenience class to handle the writing of graph checkpoint to file
*/
private[streaming]
-class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Configuration)
- extends Logging
-{
- val file = new Path(checkpointDir, "graph")
+class CheckpointWriter(
+ jobGenerator: JobGenerator,
+ conf: SparkConf,
+ checkpointDir: String,
+ hadoopConf: Configuration
+ ) extends Logging {
val MAX_ATTEMPTS = 3
val executor = Executors.newFixedThreadPool(1)
val compressionCodec = CompressionCodec.createCodec(conf)
- // The file to which we actually write - and then "move" to file
- val writeFile = new Path(file.getParent, file.getName + ".next")
- // The file to which existing checkpoint is backed up (i.e. "moved")
- val bakFile = new Path(file.getParent, file.getName + ".bk")
-
private var stopped = false
private var fs_ : FileSystem = _
- // Removed code which validates whether there is only one CheckpointWriter per path 'file' since
- // I did not notice any errors - reintroduce it ?
class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable {
def run() {
var attempts = 0
val startTime = System.currentTimeMillis()
+ val tempFile = new Path(checkpointDir, "temp")
+ val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime)
+ val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime)
+
while (attempts < MAX_ATTEMPTS && !stopped) {
attempts += 1
try {
- logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
- // This is inherently thread unsafe, so alleviating it by writing to '.new' and
- // then moving it to the final file
- val fos = fs.create(writeFile)
+ logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'")
+
+ // Write checkpoint to temp file
+ fs.delete(tempFile, true) // just in case it exists
+ val fos = fs.create(tempFile)
fos.write(bytes)
fos.close()
- if (fs.exists(file) && fs.rename(file, bakFile)) {
- logDebug("Moved existing checkpoint file to " + bakFile)
+
+ // If the checkpoint file exists, back it up
+ // If the backup exists as well, just delete it, otherwise rename will fail
+ if (fs.exists(checkpointFile)) {
+ fs.delete(backupFile, true) // just in case it exists
+ if (!fs.rename(checkpointFile, backupFile)) {
+ logWarning("Could not rename " + checkpointFile + " to " + backupFile)
+ }
+ }
+
+ // Rename temp file to the final checkpoint file
+ if (!fs.rename(tempFile, checkpointFile)) {
+ logWarning("Could not rename " + tempFile + " to " + checkpointFile)
+ }
+
+ // Delete old checkpoint files
+ val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
+ if (allCheckpointFiles.size > 4) {
+ allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
+ logInfo("Deleting " + file)
+ fs.delete(file, true)
+ })
}
- // paranoia
- fs.delete(file, false)
- fs.rename(writeFile, file)
+ // All done, print success
val finishTime = System.currentTimeMillis()
- logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
- "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
+ logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile +
+ "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms")
+ jobGenerator.onCheckpointCompletion(checkpointTime)
return
} catch {
case ioe: IOException =>
- logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
+ logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe)
reset()
}
}
- logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ logWarning("Could not write checkpoint for time " + checkpointTime + " to file " + checkpointFile + "'")
}
}
@@ -118,6 +182,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
bos.close()
try {
executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray))
+ logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
@@ -140,7 +205,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
}
private def fs = synchronized {
- if (fs_ == null) fs_ = file.getFileSystem(hadoopConf)
+ if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf)
fs_
}
@@ -153,43 +218,46 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
private[streaming]
object CheckpointReader extends Logging {
- def read(conf: SparkConf, path: String): Checkpoint = {
- val fs = new Path(path).getFileSystem(new Configuration())
- val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"),
- new Path(path), new Path(path + ".bk"))
+ def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = {
+ val checkpointPath = new Path(checkpointDir)
+ def fs = checkpointPath.getFileSystem(hadoopConf)
+
+ // Try to find the checkpoint files
+ val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse
+ if (checkpointFiles.isEmpty) {
+ return None
+ }
+ // Try to read the checkpoint files in the order
+ logInfo("Checkpoint files found: " + checkpointFiles.mkString(","))
val compressionCodec = CompressionCodec.createCodec(conf)
-
- attempts.foreach(file => {
- if (fs.exists(file)) {
- logInfo("Attempting to load checkpoint from file '" + file + "'")
- try {
- val fis = fs.open(file)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val zis = compressionCodec.compressedInputStream(fis)
- val ois = new ObjectInputStreamWithLoader(zis,
- Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- logInfo("Checkpoint successfully loaded from file '" + file + "'")
- logInfo("Checkpoint was generated at time " + cp.checkpointTime)
- return cp
- } catch {
- case e: Exception =>
- logError("Error loading checkpoint from file '" + file + "'", e)
- }
- } else {
- logWarning("Could not read checkpoint from file '" + file + "' as it does not exist")
+ checkpointFiles.foreach(file => {
+ logInfo("Attempting to load checkpoint from file " + file)
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which maybe the wrong class loader. Hence, a inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well know Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val zis = compressionCodec.compressedInputStream(fis)
+ val ois = new ObjectInputStreamWithLoader(zis,
+ Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ logInfo("Checkpoint successfully loaded from file " + file)
+ logInfo("Checkpoint was generated at time " + cp.checkpointTime)
+ return Some(cp)
+ } catch {
+ case e: Exception =>
+ logWarning("Error reading checkpoint from file " + file, e)
}
-
})
- throw new Exception("Could not read checkpoint from path '" + path + "'")
+
+ // If none of checkpoint files could be read, then throw exception
+ throw new SparkException("Failed to read checkpoint from directory " + checkpointPath)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 837f1ea1d8..b98f4a5101 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -329,13 +329,12 @@ abstract class DStream[T: ClassTag] (
* implementation clears the old generated RDDs. Subclasses of DStream may override
* this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def clearOldMetadata(time: Time) {
- var numForgotten = 0
+ protected[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
generatedRDDs --= oldRDDs.keys
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
(time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
- dependencies.foreach(_.clearOldMetadata(time))
+ dependencies.foreach(_.clearMetadata(time))
}
/* Adds metadata to the Stream while it is running.
@@ -356,12 +355,18 @@ abstract class DStream[T: ClassTag] (
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
logInfo("Updating checkpoint data for time " + currentTime)
- checkpointData.update()
+ checkpointData.update(currentTime)
dependencies.foreach(_.updateCheckpointData(currentTime))
- checkpointData.cleanup()
logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
+ protected[streaming] def clearCheckpointData(time: Time) {
+ logInfo("Clearing checkpoint data")
+ checkpointData.cleanup(time)
+ dependencies.foreach(_.clearCheckpointData(time))
+ logInfo("Cleared checkpoint data")
+ }
+
/**
* Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method
* that should not be called directly. This is a default implementation that recreates RDDs
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
index 3fd5d52403..671f7bbce7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
@@ -17,77 +17,86 @@
package org.apache.spark.streaming
+import scala.collection.mutable.{HashMap, HashSet}
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.conf.Configuration
-import collection.mutable.HashMap
import org.apache.spark.Logging
-import scala.collection.mutable.HashMap
-import scala.reflect.ClassTag
-
+import java.io.{ObjectInputStream, IOException}
private[streaming]
class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
- @transient private var fileSystem : FileSystem = null
- @transient private var lastCheckpointFiles: HashMap[Time, String] = null
+ // Mapping of the batch time to the checkpointed RDD file of that time
+ @transient private var timeToCheckpointFile = new HashMap[Time, String]
+ // Mapping of the batch time to the time of the oldest checkpointed RDD
+ // in that batch's checkpoint data
+ @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time]
- protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]]
+ @transient private var fileSystem : FileSystem = null
+ protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]]
/**
* Updates the checkpoint data of the DStream. This gets called every time
* the graph checkpoint is initiated. Default implementation records the
* checkpoint files to which the generate RDDs of the DStream has been saved.
*/
- def update() {
+ def update(time: Time) {
// Get the checkpointed RDDs from the generated RDDs
- val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
.map(x => (x._1, x._2.getCheckpointFile.get))
-
- // Make a copy of the existing checkpoint data (checkpointed RDDs)
- lastCheckpointFiles = checkpointFiles.clone()
-
- // If the new checkpoint data has checkpoints then replace existing with the new one
- if (newCheckpointFiles.size > 0) {
- checkpointFiles.clear()
- checkpointFiles ++= newCheckpointFiles
- }
-
- // TODO: remove this, this is just for debugging
- newCheckpointFiles.foreach {
- case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+ logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n"))
+
+ // Add the checkpoint files to the data to be serialized
+ if (!checkpointFiles.isEmpty) {
+ currentCheckpointFiles.clear()
+ currentCheckpointFiles ++= checkpointFiles
+ // Add the current checkpoint files to the map of all checkpoint files
+ // This will be used to delete old checkpoint files
+ timeToCheckpointFile ++= currentCheckpointFiles
+ // Remember the time of the oldest checkpoint RDD in current state
+ timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering)
}
}
/**
- * Cleanup old checkpoint data. This gets called every time the graph
- * checkpoint is initiated, but after `update` is called. Default
- * implementation, cleans up old checkpoint files.
+ * Cleanup old checkpoint data. This gets called after a checkpoint of `time` has been
+ * written to the checkpoint directory.
*/
- def cleanup() {
- // If there is at least on checkpoint file in the current checkpoint files,
- // then delete the old checkpoint files.
- if (checkpointFiles.size > 0 && lastCheckpointFiles != null) {
- (lastCheckpointFiles -- checkpointFiles.keySet).foreach {
- case (time, file) => {
- try {
- val path = new Path(file)
- if (fileSystem == null) {
- fileSystem = path.getFileSystem(new Configuration())
+ def cleanup(time: Time) {
+ // Get the time of the oldest checkpointed RDD that was written as part of the
+ // checkpoint of `time`
+ timeToOldestCheckpointFileTime.remove(time) match {
+ case Some(lastCheckpointFileTime) =>
+ // Find all the checkpointed RDDs (i.e. files) that are older than `lastCheckpointFileTime`
+ // This is because checkpointed RDDs older than this are not going to be needed
+ // even after master fails, as the checkpoint data of `time` does not refer to those files
+ val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime)
+ logDebug("Files to delete:\n" + filesToDelete.mkString(","))
+ filesToDelete.foreach {
+ case (time, file) =>
+ try {
+ val path = new Path(file)
+ if (fileSystem == null) {
+ fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration)
+ }
+ fileSystem.delete(path, true)
+ timeToCheckpointFile -= time
+ logInfo("Deleted checkpoint file '" + file + "' for time " + time)
+ } catch {
+ case e: Exception =>
+ logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
+ fileSystem = null
}
- fileSystem.delete(path, true)
- logInfo("Deleted checkpoint file '" + file + "' for time " + time)
- } catch {
- case e: Exception =>
- logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
- }
}
- }
+ case None =>
+ logInfo("Nothing to delete")
}
}
@@ -98,7 +107,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
*/
def restore() {
// Create RDDs from the checkpoint data
- checkpointFiles.foreach {
+ currentCheckpointFiles.foreach {
case(time, file) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
@@ -107,6 +116,13 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
}
override def toString() = {
- "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]"
+ "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]"
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ timeToOldestCheckpointFileTime = new HashMap[Time, Time]
+ timeToCheckpointFile = new HashMap[Time, String]
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 62d07b22c6..eee9591ffc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -104,36 +104,44 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def getOutputStreams() = this.synchronized { outputStreams.toArray }
def generateJobs(time: Time): Seq[Job] = {
- this.synchronized {
- logDebug("Generating jobs for time " + time)
- val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time))
- logDebug("Generated " + jobs.length + " jobs for time " + time)
- jobs
+ logDebug("Generating jobs for time " + time)
+ val jobs = this.synchronized {
+ outputStreams.flatMap(outputStream => outputStream.generateJob(time))
}
+ logDebug("Generated " + jobs.length + " jobs for time " + time)
+ jobs
}
- def clearOldMetadata(time: Time) {
+ def clearMetadata(time: Time) {
+ logDebug("Clearing metadata for time " + time)
this.synchronized {
- logDebug("Clearing old metadata for time " + time)
- outputStreams.foreach(_.clearOldMetadata(time))
- logDebug("Cleared old metadata for time " + time)
+ outputStreams.foreach(_.clearMetadata(time))
}
+ logDebug("Cleared old metadata for time " + time)
}
def updateCheckpointData(time: Time) {
+ logInfo("Updating checkpoint data for time " + time)
this.synchronized {
- logInfo("Updating checkpoint data for time " + time)
outputStreams.foreach(_.updateCheckpointData(time))
- logInfo("Updated checkpoint data for time " + time)
}
+ logInfo("Updated checkpoint data for time " + time)
+ }
+
+ def clearCheckpointData(time: Time) {
+ logInfo("Clearing checkpoint data for time " + time)
+ this.synchronized {
+ outputStreams.foreach(_.clearCheckpointData(time))
+ }
+ logInfo("Cleared checkpoint data for time " + time)
}
def restoreCheckpointData() {
+ logInfo("Restoring checkpoint data")
this.synchronized {
- logInfo("Restoring checkpoint data")
outputStreams.foreach(_.restoreCheckpointData())
- logInfo("Restored checkpoint data")
}
+ logInfo("Restored checkpoint data")
}
def validate() {
@@ -146,8 +154,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
+ logDebug("DStreamGraph.writeObject used")
this.synchronized {
- logDebug("DStreamGraph.writeObject used")
checkpointInProgress = true
oos.defaultWriteObject()
checkpointInProgress = false
@@ -156,8 +164,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream) {
+ logDebug("DStreamGraph.readObject used")
this.synchronized {
- logDebug("DStreamGraph.readObject used")
checkpointInProgress = true
ois.defaultReadObject()
checkpointInProgress = false
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 693cb7fc30..dd34f6f4f2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -39,6 +39,7 @@ import org.apache.spark.util.MetadataCleaner
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receivers._
import org.apache.spark.streaming.scheduler._
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -88,10 +89,12 @@ class StreamingContext private (
/**
* Re-create a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ * @param hadoopConf Optional, configuration object if necessary for reading from
+ * HDFS compatible filesystems
*/
- def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null)
+ def this(path: String, hadoopConf: Configuration = new Configuration) =
+ this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null)
if (sc_ == null && cp_ == null) {
throw new Exception("Spark Streaming cannot be initialized with " +
@@ -171,8 +174,9 @@ class StreamingContext private (
/**
* Set the context to periodically checkpoint the DStream operations for master
- * fault-tolerance. The graph will be checkpointed every batch interval.
- * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * fault-tolerance.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored.
+ * Note that this must be a fault-tolerant file system like HDFS for
*/
def checkpoint(directory: String) {
if (directory != null) {
@@ -461,26 +465,64 @@ class StreamingContext private (
}
}
+/**
+ * StreamingContext object contains a number of utility functions related to the
+ * StreamingContext class.
+ */
-object StreamingContext {
+object StreamingContext extends Logging {
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
/**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by called the provided `creatingFunc`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new StreamingContext
+ * @param hadoopConf Optional Hadoop configuration if necessary for reading from the
+ * file system
+ * @param createOnError Optional, whether to create a new StreamingContext if there is an
+ * error in reading checkpoint data. By default, an exception will be
+ * thrown on error.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: () => StreamingContext,
+ hadoopConf: Configuration = new Configuration(),
+ createOnError: Boolean = false
+ ): StreamingContext = {
+ val checkpointOption = try {
+ CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf)
+ } catch {
+ case e: Exception =>
+ if (createOnError) {
+ None
+ } else {
+ throw e
+ }
+ }
+ checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc())
+ }
+
+ /**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
- * their JARs to SparkContext.
+ * their JARs to StreamingContext.
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
+
protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
- val sc = new SparkContext(conf)
- if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
- MetadataCleaner.setDelaySeconds(sc.conf, 3600)
+ if (MetadataCleaner.getDelaySeconds(conf) < 0) {
+ MetadataCleaner.setDelaySeconds(conf, 3600)
}
+ val sc = new SparkContext(conf)
sc
}
@@ -489,14 +531,17 @@ object StreamingContext {
appName: String,
sparkHome: String,
jars: Seq[String],
- environment: Map[String, String]): SparkContext =
- {
- val sc = new SparkContext(master, appName, sparkHome, jars, environment)
+ environment: Map[String, String]
+ ): SparkContext = {
+
+ val conf = SparkContext.updatedConf(
+ new SparkConf(), master, appName, sparkHome, jars, environment)
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
- MetadataCleaner.setDelaySeconds(sc.conf, 3600)
+ if (MetadataCleaner.getDelaySeconds(conf) < 0) {
+ MetadataCleaner.setDelaySeconds(conf, 3600)
}
+ val sc = new SparkContext(master, appName, sparkHome, jars, environment)
sc
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 7068f32517..523173d45a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -128,10 +129,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Re-creates a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this (new StreamingContext(path))
+ def this(path: String) = this(new StreamingContext(path))
+
+ /**
+ * Re-creates a StreamingContext from a checkpoint file.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ *
+ */
+ def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf))
/** The underlying SparkContext */
val sc: JavaSparkContext = new JavaSparkContext(ssc.sc)
@@ -471,20 +478,97 @@ class JavaStreamingContext(val ssc: StreamingContext) {
}
/**
- * Starts the execution of the streams.
+ * Start the execution of the streams.
*/
def start() = ssc.start()
/**
- * Sstops the execution of the streams.
+ * Stop the execution of the streams.
*/
def stop() = ssc.stop()
}
+/**
+ * JavaStreamingContext object contains a number of utility functions.
+ */
object JavaStreamingContext {
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ })
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ * @param createOnError Whether to create a new JavaStreamingContext if there is an
+ * error in reading checkpoint data.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory,
+ createOnError: Boolean
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf, createOnError)
+ new JavaStreamingContext(ssc)
+ }
+
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
- * their JARs to SparkContext.
+ * their JARs to StreamingContext.
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray
}
+
+/**
+ * Factory interface for creating a new JavaStreamingContext
+ */
+trait JavaStreamingContextFactory {
+ def create(): JavaStreamingContext
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index fb9eda8996..1f0f31c4b1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -23,10 +23,10 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time}
+import org.apache.spark.util.TimeStampedHashMap
private[streaming]
@@ -46,6 +46,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
@transient private var path_ : Path = null
@transient private var fs_ : FileSystem = null
@transient private[streaming] var files = new HashMap[Time, Array[String]]
+ @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true)
+ @transient private var lastNewFileFindingTime = 0L
override def start() {
if (newFilesOnly) {
@@ -88,14 +90,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
}
/** Clear the old time-to-files mappings along with old RDDs */
- protected[streaming] override def clearOldMetadata(time: Time) {
- super.clearOldMetadata(time)
+ protected[streaming] override def clearMetadata(time: Time) {
+ super.clearMetadata(time)
val oldFiles = files.filter(_._1 <= (time - rememberDuration))
files --= oldFiles.keys
logInfo("Cleared " + oldFiles.size + " old files that were older than " +
(time - rememberDuration) + ": " + oldFiles.keys.mkString(", "))
logDebug("Cleared files are:\n" +
oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n"))
+ // Delete file mod times that weren't accessed in the last round of getting new files
+ fileModTimes.clearOldValues(lastNewFileFindingTime - 1)
}
/**
@@ -104,8 +108,19 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
*/
private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = {
logDebug("Trying to get new files for time " + currentTime)
+ lastNewFileFindingTime = System.currentTimeMillis
val filter = new CustomPathFilter(currentTime)
- val newFiles = fs.listStatus(path, filter).map(_.getPath.toString)
+ val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString)
+ val timeTaken = System.currentTimeMillis - lastNewFileFindingTime
+ logInfo("Finding new files took " + timeTaken + " ms")
+ logDebug("# cached file times = " + fileModTimes.size)
+ if (timeTaken > slideDuration.milliseconds) {
+ logWarning(
+ "Time taken to find new files exceeds the batch size. " +
+ "Consider increasing the batch size or reduceing the number of " +
+ "files in the monitored directory."
+ )
+ }
(newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq)
}
@@ -122,16 +137,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
new UnionRDD(context.sparkContext, fileRDDs)
}
- private def path: Path = {
+ private def directoryPath: Path = {
if (path_ == null) path_ = new Path(directory)
path_
}
private def fs: FileSystem = {
- if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+ if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration())
fs_
}
+ private def getFileModTime(path: Path) = {
+ // Get file mod time from cache or fetch it from the file system
+ fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime())
+ }
+
private def reset() {
fs_ = null
}
@@ -142,6 +162,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
ois.defaultReadObject()
generatedRDDs = new HashMap[Time, RDD[(K,V)]] ()
files = new HashMap[Time, Array[String]]
+ fileModTimes = new TimeStampedHashMap[String, Long](true)
}
/**
@@ -153,15 +174,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]]
- override def update() {
+ override def update(time: Time) {
hadoopFiles.clear()
hadoopFiles ++= files
}
- override def cleanup() { }
+ override def cleanup(time: Time) { }
override def restore() {
- hadoopFiles.foreach {
+ hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach {
case (t, f) => {
// Restore the metadata in both files and generatedRDDs
logInfo("Restoring files for time " + t + " - " +
@@ -187,14 +208,13 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
// Latest file mod time seen in this round of fetching files and its corresponding files
var latestModTime = 0L
val latestModTimeFiles = new HashSet[String]()
-
def accept(path: Path): Boolean = {
try {
if (!filter(path)) { // Reject file if it does not satisfy filter
logDebug("Rejected by filter " + path)
return false
}
- val modTime = fs.getFileStatus(path).getModificationTime()
+ val modTime = getFileModTime(path)
logDebug("Mod time for " + path + " is " + modTime)
if (modTime < prevModTime) {
logDebug("Mod time less than last mod time")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 3c624e8199..2fa6853ae0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -26,8 +26,9 @@ import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
-private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent
private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent
/**
* This class generates jobs from DStreams as well as drives checkpointing and cleaning
@@ -53,7 +54,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+ new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
@@ -77,15 +78,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
* On batch completion, clear old metadata and checkpoint computation.
*/
private[scheduler] def onBatchCompletion(time: Time) {
- eventProcessorActor ! ClearOldMetadata(time)
+ eventProcessorActor ! ClearMetadata(time)
+ }
+
+ private[streaming] def onCheckpointCompletion(time: Time) {
+ eventProcessorActor ! ClearCheckpointData(time)
}
/** Processes all events */
private def processEvent(event: JobGeneratorEvent) {
event match {
case GenerateJobs(time) => generateJobs(time)
- case ClearOldMetadata(time) => clearOldMetadata(time)
+ case ClearMetadata(time) => clearMetadata(time)
case DoCheckpoint(time) => doCheckpoint(time)
+ case ClearCheckpointData(time) => clearCheckpointData(time)
}
}
@@ -115,14 +121,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val checkpointTime = ssc.initialCheckpoint.checkpointTime
val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
val downTimes = checkpointTime.until(restartTime, batchDuration)
- logInfo("Batches during down time: " + downTimes.mkString(", "))
+ logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", "))
// Batches that were unprocessed before failure
- val pendingTimes = ssc.initialCheckpoint.pendingTimes
- logInfo("Batches pending processing: " + pendingTimes.mkString(", "))
+ val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering)
+ logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", "))
// Reschedule jobs for these times
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
- logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
+ logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
jobScheduler.runJobs(time, graph.generateJobs(time))
)
@@ -141,11 +147,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
/** Clear DStream metadata for the given `time`. */
- private def clearOldMetadata(time: Time) {
- ssc.graph.clearOldMetadata(time)
+ private def clearMetadata(time: Time) {
+ ssc.graph.clearMetadata(time)
eventProcessorActor ! DoCheckpoint(time)
}
+ /** Clear DStream checkpoint data for the given `time`. */
+ private def clearCheckpointData(time: Time) {
+ ssc.graph.clearCheckpointData(time)
+ }
+
/** Perform checkpoint for the give `time`. */
private def doCheckpoint(time: Time) = synchronized {
if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 1559f7a9f7..162b19d7f0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -42,6 +42,7 @@ object MasterFailureTest extends Logging {
@volatile var killed = false
@volatile var killCount = 0
+ @volatile var setupCalled = false
def main(args: Array[String]) {
if (args.size < 2) {
@@ -131,8 +132,26 @@ object MasterFailureTest extends Logging {
// Just making sure that the expected output does not have duplicates
assert(expectedOutput.distinct.toSet == expectedOutput.toSet)
+ // Reset all state
+ reset()
+
+ // Create the directories for this test
+ val uuid = UUID.randomUUID().toString
+ val rootDir = new Path(directory, uuid)
+ val fs = rootDir.getFileSystem(new Configuration())
+ val checkpointDir = new Path(rootDir, "checkpoint")
+ val testDir = new Path(rootDir, "test")
+ fs.mkdirs(checkpointDir)
+ fs.mkdirs(testDir)
+
// Setup the stream computation with the given operation
- val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation)
+ val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => {
+ setupStreams(batchDuration, operation, checkpointDir, testDir)
+ })
+
+ // Check if setupStream was called to create StreamingContext
+ // (and not created from checkpoint file)
+ assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate")
// Start generating files in the a different thread
val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
@@ -144,9 +163,7 @@ object MasterFailureTest extends Logging {
val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
- // Delete directories
fileGeneratingThread.join()
- val fs = checkpointDir.getFileSystem(new Configuration())
fs.delete(checkpointDir, true)
fs.delete(testDir, true)
logInfo("Finished test after " + killCount + " failures")
@@ -159,32 +176,24 @@ object MasterFailureTest extends Logging {
* files should be written for testing.
*/
private def setupStreams[T: ClassTag](
- directory: String,
batchDuration: Duration,
- operation: DStream[String] => DStream[T]
- ): (StreamingContext, Path, Path) = {
- // Reset all state
- reset()
-
- // Create the directories for this test
- val uuid = UUID.randomUUID().toString
- val rootDir = new Path(directory, uuid)
- val fs = rootDir.getFileSystem(new Configuration())
- val checkpointDir = new Path(rootDir, "checkpoint")
- val testDir = new Path(rootDir, "test")
- fs.mkdirs(checkpointDir)
- fs.mkdirs(testDir)
+ operation: DStream[String] => DStream[T],
+ checkpointDir: Path,
+ testDir: Path
+ ): StreamingContext = {
+ // Mark that setup was called
+ setupCalled = true
// Setup the streaming computation with the given operation
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
- var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
+ val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
ssc.checkpoint(checkpointDir.toString)
val inputStream = ssc.textFileStream(testDir.toString)
val operatedStream = operation(inputStream)
val outputStream = new TestOutputStream(operatedStream)
ssc.registerOutputStream(outputStream)
- (ssc, checkpointDir, testDir)
+ ssc
}
@@ -204,7 +213,7 @@ object MasterFailureTest extends Logging {
var isTimedOut = false
val mergedOutput = new ArrayBuffer[T]()
val checkpointDir = ssc.checkpointDir
- var batchDuration = ssc.graph.batchDuration
+ val batchDuration = ssc.graph.batchDuration
while(!isLastOutputGenerated && !isTimedOut) {
// Get the output buffer
@@ -261,7 +270,10 @@ object MasterFailureTest extends Logging {
)
Thread.sleep(sleepTime)
// Recreate the streaming context from checkpoint
- ssc = new StreamingContext(checkpointDir)
+ ssc = StreamingContext.getOrCreate(checkpointDir, () => {
+ throw new Exception("Trying to create new context when it " +
+ "should be reading from checkpoint file")
+ })
}
}
mergedOutput
@@ -297,6 +309,7 @@ object MasterFailureTest extends Logging {
private def reset() {
killed = false
killCount = 0
+ setupCalled = false
}
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 0d2145da9a..8b7d7709bf 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -28,6 +28,7 @@ import java.util.*;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
+import com.google.common.collect.Sets;
import org.apache.spark.SparkConf;
import org.apache.spark.HashPartitioner;
@@ -441,13 +442,13 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
new Tuple2<String, String>("new york", "islanders")));
- List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
- Arrays.asList(
+ List<HashSet<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Tuple2<String, String>>("california",
new Tuple2<String, String>("dodgers", "giants")),
new Tuple2<String, Tuple2<String, String>>("new york",
- new Tuple2<String, String>("yankees", "mets"))),
- Arrays.asList(
+ new Tuple2<String, String>("yankees", "mets"))),
+ Sets.newHashSet(
new Tuple2<String, Tuple2<String, String>>("california",
new Tuple2<String, String>("sharks", "ducks")),
new Tuple2<String, Tuple2<String, String>>("new york",
@@ -482,8 +483,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaTestUtils.attachTestOutputStream(joined);
List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<HashSet<Tuple2<String, Tuple2<String, String>>>> unorderedResult = Lists.newArrayList();
+ for (List<Tuple2<String, Tuple2<String, String>>> res: result) {
+ unorderedResult.add(Sets.newHashSet(res));
+ }
- Assert.assertEquals(expected, result);
+ Assert.assertEquals(expected, unorderedResult);
}
@@ -1196,15 +1201,15 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
Arrays.asList("hello", "moon"),
Arrays.asList("hello"));
- List<List<Tuple2<String, Long>>> expected = Arrays.asList(
- Arrays.asList(
+ List<HashSet<Tuple2<String, Long>>> expected = Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 1L),
new Tuple2<String, Long>("world", 1L)),
- Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 2L),
new Tuple2<String, Long>("world", 1L),
new Tuple2<String, Long>("moon", 1L)),
- Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 2L),
new Tuple2<String, Long>("moon", 1L)));
@@ -1214,8 +1219,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
stream.countByValueAndWindow(new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(counted);
List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+ List<HashSet<Tuple2<String, Long>>> unorderedResult = Lists.newArrayList();
+ for (List<Tuple2<String, Long>> res: result) {
+ unorderedResult.add(Sets.newHashSet(res));
+ }
- Assert.assertEquals(expected, result);
+ Assert.assertEquals(expected, unorderedResult);
}
@Test
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 8dc80ac2ed..6499de98c9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -84,9 +84,9 @@ class CheckpointSuite extends TestSuiteBase {
ssc.start()
advanceTimeWithRealDelay(ssc, firstNumBatches)
logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData)
- assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+ assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before first failure")
- stateStream.checkpointData.checkpointFiles.foreach {
+ stateStream.checkpointData.currentCheckpointFiles.foreach {
case (time, file) => {
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before first failure does not exist")
@@ -95,7 +95,7 @@ class CheckpointSuite extends TestSuiteBase {
// Run till a further time such that previous checkpoint files in the stream would be deleted
// and check whether the earlier checkpoint files are deleted
- val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2))
+ val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2))
advanceTimeWithRealDelay(ssc, secondNumBatches)
checkpointFiles.foreach(file =>
assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
@@ -114,9 +114,9 @@ class CheckpointSuite extends TestSuiteBase {
// is present in the checkpoint data or not
ssc.start()
advanceTimeWithRealDelay(ssc, 1)
- assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+ assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before second failure")
- stateStream.checkpointData.checkpointFiles.foreach {
+ stateStream.checkpointData.currentCheckpointFiles.foreach {
case (time, file) => {
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before seconds failure does not exist")
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 2bb11e54c5..2e46d750c4 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -127,14 +127,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// local dirs, so lets check both. We assume one of the 2 is set.
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
- .getOrElse(Option(System.getenv("LOCAL_DIRS"))
- .getOrElse(""))
-
- if (localDirs.isEmpty()) {
- throw new Exception("Yarn Local dirs can't be empty")
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
}
- localDirs
- }
+ }
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index ddfec1a4ac..62b20b8fba 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -76,6 +76,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
def run() {
+ // Setup the directories so things go to yarn approved directories rather
+ // then user specified and /tmp.
+ System.setProperty("spark.local.dir", getLocalDirs())
+
appAttemptId = getApplicationAttemptId()
resourceManager = registerWithResourceManager()
val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
@@ -103,10 +107,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
// ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- // must be <= timeoutInterval/ 2.
- // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
- val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
reporterThread = launchReporterThread(interval)
// Wait for the reporter thread to Finish.
@@ -119,6 +125,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
System.exit(0)
}
+ /** Get the Yarn approved local directories. */
+ private def getLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x have different Environment variable names for the
+ // local dirs, so lets check both. We assume one of the 2 is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
+ }
+ }
+
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 4b1b5da048..22e55e0c60 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -22,6 +22,8 @@ import org.apache.spark.{SparkException, Logging, SparkContext}
import org.apache.spark.deploy.yarn.{Client, ClientArguments}
import org.apache.spark.scheduler.TaskSchedulerImpl
+import scala.collection.mutable.ArrayBuffer
+
private[spark] class YarnClientSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext)
@@ -31,45 +33,47 @@ private[spark] class YarnClientSchedulerBackend(
var client: Client = null
var appId: ApplicationId = null
+ private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) {
+ Option(System.getenv(optionalParam)) foreach {
+ optParam => {
+ arrayBuf += (optionName, optParam)
+ }
+ }
+ }
+
override def start() {
super.start()
- val defalutWorkerCores = "2"
- val defalutWorkerMemory = "512m"
- val defaultWorkerNumber = "1"
-
val userJar = System.getenv("SPARK_YARN_APP_JAR")
- val distFiles = System.getenv("SPARK_YARN_DIST_FILES")
- var workerCores = System.getenv("SPARK_WORKER_CORES")
- var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
- var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
-
if (userJar == null)
throw new SparkException("env SPARK_YARN_APP_JAR is not set")
- if (workerCores == null)
- workerCores = defalutWorkerCores
- if (workerMemory == null)
- workerMemory = defalutWorkerMemory
- if (workerNumber == null)
- workerNumber = defaultWorkerNumber
-
val driverHost = conf.get("spark.driver.host")
val driverPort = conf.get("spark.driver.port")
val hostport = driverHost + ":" + driverPort
- val argsArray = Array[String](
+ val argsArrayBuf = new ArrayBuffer[String]()
+ argsArrayBuf += (
"--class", "notused",
"--jar", userJar,
"--args", hostport,
- "--worker-memory", workerMemory,
- "--worker-cores", workerCores,
- "--num-workers", workerNumber,
- "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher",
- "--files", distFiles
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
)
- val args = new ClientArguments(argsArray, conf)
+ // process any optional arguments, use the defaults already defined in ClientArguments
+ // if things aren't specified
+ Map("--master-memory" -> "SPARK_MASTER_MEMORY",
+ "--num-workers" -> "SPARK_WORKER_INSTANCES",
+ "--worker-memory" -> "SPARK_WORKER_MEMORY",
+ "--worker-cores" -> "SPARK_WORKER_CORES",
+ "--queue" -> "SPARK_YARN_QUEUE",
+ "--name" -> "SPARK_YARN_APP_NAME",
+ "--files" -> "SPARK_YARN_DIST_FILES",
+ "--archives" -> "SPARK_YARN_DIST_ARCHIVES")
+ .foreach { case (optName, optParam) => addArg(optName, optParam, argsArrayBuf) }
+
+ logDebug("ClientArguments called with: " + argsArrayBuf)
+ val args = new ClientArguments(argsArrayBuf.toArray, conf)
client = new Client(args, conf)
appId = client.runApp()
waitForApp()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 69ae14ce83..4b777d5fa7 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -116,14 +116,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// local dirs, so lets check both. We assume one of the 2 is set.
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
- .getOrElse(Option(System.getenv("LOCAL_DIRS"))
- .getOrElse(""))
-
- if (localDirs.isEmpty()) {
- throw new Exception("Yarn Local dirs can't be empty")
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
}
- localDirs
- }
+ }
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index be323d7783..952e963389 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -99,6 +99,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
appContext.setApplicationName(args.appName)
appContext.setQueue(args.amQueue)
appContext.setAMContainerSpec(amContainer)
+ appContext.setApplicationType("SPARK")
// Memory for the ApplicationMaster.
val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 49248a8516..78353224fa 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -78,6 +78,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
def run() {
+ // Setup the directories so things go to yarn approved directories rather
+ // then user specified and /tmp.
+ System.setProperty("spark.local.dir", getLocalDirs())
+
amClient = AMRMClient.createAMRMClient()
amClient.init(yarnConf)
amClient.start()
@@ -94,10 +98,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
// ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- // must be <= timeoutInterval/ 2.
- // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
- val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
reporterThread = launchReporterThread(interval)
// Wait for the reporter thread to Finish.
@@ -110,6 +116,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
System.exit(0)
}
+ /** Get the Yarn approved local directories. */
+ private def getLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x have different Environment variable names for the
+ // local dirs, so lets check both. We assume one of the 2 is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
+ }
+ }
+
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())