aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/pyspark7
-rw-r--r--conf/log4j.properties.template4
-rw-r--r--core/src/main/resources/org/apache/spark/log4j-defaults.properties4
-rw-r--r--core/src/main/scala/org/apache/spark/Aggregator.scala61
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala88
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Vector.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala (renamed from core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala)116
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala350
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala101
-rw-r--r--core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala120
-rw-r--r--core/src/test/scala/org/apache/spark/util/VectorSuite.scala44
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala)46
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala230
-rw-r--r--docs/configuration.md25
-rw-r--r--docs/python-programming-guide.md5
-rw-r--r--docs/running-on-yarn.md15
-rwxr-xr-xec2/spark_ec2.py12
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java2
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java2
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java11
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala12
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala3
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala5
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala8
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala7
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala8
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala5
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala118
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala21
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala10
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala9
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala1
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala188
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStream.scala17
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala106
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala38
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala75
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala96
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala42
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala31
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala55
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java29
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala2
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala10
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala13
-rw-r--r--yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala28
-rw-r--r--yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala50
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala13
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala1
-rw-r--r--yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala28
69 files changed, 1982 insertions, 402 deletions
diff --git a/bin/pyspark b/bin/pyspark
index d6810f4686..ed6f8da730 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -59,12 +59,7 @@ if [ -n "$IPYTHON_OPTS" ]; then
fi
if [[ "$IPYTHON" = "1" ]] ; then
- # IPython <1.0.0 doesn't honor PYTHONSTARTUP, while 1.0.0+ does.
- # Hence we clear PYTHONSTARTUP and use the -c "%run $IPYTHONSTARTUP" command which works on all versions
- # We also force interactive mode with "-i"
- IPYTHONSTARTUP=$PYTHONSTARTUP
- PYTHONSTARTUP=
- exec ipython "$IPYTHON_OPTS" -i -c "%run $IPYTHONSTARTUP"
+ exec ipython $IPYTHON_OPTS
else
exec "$PYSPARK_PYTHON" "$@"
fi
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
index 17d1978dde..f7f8535594 100644
--- a/conf/log4j.properties.template
+++ b/conf/log4j.properties.template
@@ -5,5 +5,7 @@ log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
+# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 17d1978dde..f7f8535594 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -5,5 +5,7 @@ log4j.appender.console.target=System.err
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
-# Ignore messages below warning level from Jetty, because it's a bit verbose
+# Settings to quiet third party logs that are too verbose
log4j.logger.org.eclipse.jetty=WARN
+log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
+log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 1a2ec55876..8b30cd4bfe 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.apache.spark.util.AppendOnlyMap
+import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
/**
* A set of functions used to aggregate data.
@@ -31,30 +31,51 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
+ private val sparkConf = SparkEnv.get.conf
+ private val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
+
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kv: Product2[K, V] = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
- }
- while (iter.hasNext) {
- kv = iter.next()
- combiners.changeValue(kv._1, update)
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners =
+ new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ combiners.insert(k, v)
+ }
+ combiners.iterator
}
- combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new AppendOnlyMap[K, C]
- var kc: (K, C) = null
- val update = (hadValue: Boolean, oldValue: C) => {
- if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ if (!externalSorting) {
+ val combiners = new AppendOnlyMap[K,C]
+ var kc: Product2[K, C] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
+ }
+ combiners.iterator
+ } else {
+ val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
+ while (iter.hasNext) {
+ val (k, c) = iter.next()
+ combiners.insert(k, c)
+ }
+ combiners.iterator
}
- while (iter.hasNext) {
- kc = iter.next()
- combiners.changeValue(kc._1, update)
- }
- combiners.iterator
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 66c226e491..139048d5c7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -677,10 +677,10 @@ class SparkContext(
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
- if (SparkHadoopUtil.get.isYarnMode()) {
- // In order for this to work on yarn the user must specify the --addjars option to
- // the client to upload the file into the distributed cache to make it show up in the
- // current working directory.
+ if (SparkHadoopUtil.get.isYarnMode() && master == "yarn-standalone") {
+ // In order for this to work in yarn standalone mode the user must specify the
+ // --addjars option to the client to upload the file into the distributed cache
+ // of the AM to make it show up in the current working directory.
val fileName = new Path(uri.getPath).getName()
try {
env.httpFileServer.addJar(new File(fileName))
@@ -1086,7 +1086,7 @@ object SparkContext {
* parameters that are passed as the default value of null, instead of throwing an exception
* like SparkConf would.
*/
- private def updatedConf(
+ private[spark] def updatedConf(
conf: SparkConf,
master: String,
appName: String,
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e093e2f162..08b592df71 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -54,7 +54,11 @@ class SparkEnv private[spark] (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- val conf: SparkConf) {
+ val conf: SparkConf) extends Logging {
+
+ // A mapping of thread ID to amount of memory used for shuffle in bytes
+ // All accesses should be manually synchronized
+ val shuffleMemoryMap = mutable.HashMap[Long, Long]()
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index dbb0cb90f5..9485bfd89e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -67,11 +67,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) {
<li><strong>User:</strong> {app.desc.user}</li>
<li><strong>Cores:</strong>
{
- if (app.desc.maxCores == Integer.MAX_VALUE) {
+ if (app.desc.maxCores == None) {
"Unlimited (%s granted)".format(app.coresGranted)
} else {
"%s (%s granted, %s left)".format(
- app.desc.maxCores, app.coresGranted, app.coresLeft)
+ app.desc.maxCores.get, app.coresGranted, app.coresLeft)
}
}
</li>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e51d274d33..a7b2328a02 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -279,6 +279,11 @@ private[spark] class Executor(
//System.exit(1)
}
} finally {
+ // TODO: Unregister shuffle memory only for ShuffleMapTask
+ val shuffleMemoryMap = env.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap.remove(Thread.currentThread().getId)
+ }
runningTasks.remove(taskId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 4ba4696fef..a73714abca 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -23,8 +23,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
-import org.apache.spark.util.AppendOnlyMap
-
+import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,14 +43,12 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
-private[spark]
-class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
-
/**
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
@@ -62,6 +59,14 @@ class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+ // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // CoGroupValue is the intermediate state of each value before being merged in compute.
+ private type CoGroup = ArrayBuffer[Any]
+ private type CoGroupValue = (Any, Int) // Int is dependency number
+ private type CoGroupCombiner = Seq[CoGroup]
+
+ private val sparkConf = SparkEnv.get.conf
private var serializerClass: String = null
def setSerializer(cls: String): CoGroupedRDD[K] = {
@@ -100,37 +105,74 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
override val partitioner = Some(part)
- override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, CoGroupCombiner)] = {
+ val externalSorting = sparkConf.getBoolean("spark.shuffle.externalSorting", true)
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
- // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
- if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
- }
-
- val getSeq = (k: K) => {
- map.changeValue(k, update)
- }
-
- val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)
+ // A list of (rdd iterator, dependency number) pairs
+ val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
- rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
- getSeq(kv._1)(depNum) += kv._2
- }
+ val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+ rddIterators += ((it, depNum))
}
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
- kv => getSeq(kv._1)(depNum) += kv._2
+ val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf)
+ val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ rddIterators += ((it, depNum))
+ }
+ }
+
+ if (!externalSorting) {
+ val map = new AppendOnlyMap[K, CoGroupCombiner]
+ val update: (Boolean, CoGroupCombiner) => CoGroupCombiner = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new CoGroup)
+ }
+ val getCombiner: K => CoGroupCombiner = key => {
+ map.changeValue(key, update)
+ }
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ getCombiner(kv._1)(depNum) += kv._2
}
}
+ new InterruptibleIterator(context, map.iterator)
+ } else {
+ val map = createExternalMap(numRdds)
+ rddIterators.foreach { case (it, depNum) =>
+ while (it.hasNext) {
+ val kv = it.next()
+ map.insert(kv._1, new CoGroupValue(kv._2, depNum))
+ }
+ }
+ new InterruptibleIterator(context, map.iterator)
+ }
+ }
+
+ private def createExternalMap(numRdds: Int)
+ : ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner] = {
+
+ val createCombiner: (CoGroupValue => CoGroupCombiner) = value => {
+ val newCombiner = Array.fill(numRdds)(new CoGroup)
+ value match { case (v, depNum) => newCombiner(depNum) += v }
+ newCombiner
}
- new InterruptibleIterator(context, map.iterator)
+ val mergeValue: (CoGroupCombiner, CoGroupValue) => CoGroupCombiner =
+ (combiner, value) => {
+ value match { case (v, depNum) => combiner(depNum) += v }
+ combiner
+ }
+ val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
+ (combiner1, combiner2) => {
+ combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 }
+ }
+ new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
+ createCombiner, mergeValue, mergeCombiners)
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c118ddfc01..1248409e35 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -99,8 +99,6 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
- // A sanity check to make sure mergeCombiners is not defined.
- assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
@@ -267,8 +265,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2
val bufs = combineByKey[ArrayBuffer[V]](
- createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
+ createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@@ -339,7 +338,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
* existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
- : RDD[(K, C)] = {
+ : RDD[(K, C)] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 043e01dbfb..38b536023b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -106,7 +106,7 @@ class DAGScheduler(
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50.milliseconds
+ val RESUBMIT_TIMEOUT = 200.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
@@ -196,7 +196,7 @@ class DAGScheduler(
*/
def receive = {
case event: DAGSchedulerEvent =>
- logDebug("Got event of type " + event.getClass.getName)
+ logTrace("Got event of type " + event.getClass.getName)
/**
* All events are forwarded to `processEvent()`, so that the event processing logic can
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7156d855d8..301d784b35 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -17,12 +17,14 @@
package org.apache.spark.storage
+import java.util.UUID
+
/**
* Identifies a particular Block of data, usually associated with a single file.
* A Block can be uniquely identified by its filename, but each type of Block has a different
* set of keys which produce its unique name.
*
- * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
+ * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method.
*/
private[spark] sealed abstract class BlockId {
/** A globally unique identifier for this Block. Can be used for ser/de. */
@@ -55,7 +57,8 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
-private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+private[spark]
+case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
}
@@ -67,6 +70,11 @@ private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends B
def name = "input-" + streamId + "-" + uniqueId
}
+/** Id associated with temporary data managed as blocks. Not serializable. */
+private[spark] case class TempBlockId(id: UUID) extends BlockId {
+ def name = "temp_" + id
+}
+
// Intended only for testing purposes
private[spark] case class TestBlockId(id: String) extends BlockId {
def name = "test_" + id
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c56e2ca2df..ff9f241fc1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -159,7 +159,7 @@ private[spark] class BlockManager(
/**
* Reregister with the master and report all blocks to it. This will be called by the heart beat
- * thread if our heartbeat to the block amnager indicates that we were not registered.
+ * thread if our heartbeat to the block manager indicates that we were not registered.
*
* Note that this method must be called without any BlockInfo locks held.
*/
@@ -864,7 +864,7 @@ private[spark] object BlockManager extends Logging {
val ID_GENERATOR = new IdGenerator
def getMaxMemory(conf: SparkConf): Long = {
- val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.66)
+ val memoryFraction = conf.getDouble("spark.storage.memoryFraction", 0.6)
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 61e63c60d5..369a277232 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -181,4 +181,8 @@ class DiskBlockObjectWriter(
// Only valid if called after close()
override def timeWriting() = _timeWriting
+
+ def bytesWritten: Long = {
+ lastValidPosition - initialPosition
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index edc1133172..a8ef7fa8b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
-import java.util.{Date, Random}
+import java.util.{Date, Random, UUID}
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
@@ -90,6 +90,15 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Produces a unique block id and File suitable for intermediate results. */
+ def createTempBlock(): (TempBlockId, File) = {
+ var blockId = new TempBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index e2b24298a5..6e0ff143b7 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -64,7 +64,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
- conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ conf.getBoolean("spark.shuffle.consolidateFiles", true)
private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 8dcfeacb60..d1e58016be 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -171,7 +171,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
- <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++
+ <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq() ++
<h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3d1e90a352..ac07a55cb9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -74,7 +74,7 @@ object MetadataCleanerType extends Enumeration {
// initialization of StreamingContext. It's okay for users trying to configure stuff themselves.
object MetadataCleaner {
def getDelaySeconds(conf: SparkConf) = {
- conf.getInt("spark.cleaner.ttl", 3500)
+ conf.getInt("spark.cleaner.ttl", -1)
}
def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 181ae2fd45..8e07a0f29a 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -26,16 +26,23 @@ import org.apache.spark.Logging
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
- * time stamp along with each key-value pair. Key-value pairs that are older than a particular
- * threshold time can them be removed using the clearOldValues method. This is intended to be a drop-in
- * replacement of scala.collection.mutable.HashMap.
+ * timestamp along with each key-value pair. If specified, the timestamp of each pair can be
+ * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular
+ * threshold time can then be removed using the clearOldValues method. This is intended to
+ * be a drop-in replacement of scala.collection.mutable.HashMap.
+ * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
+ * updated when it is accessed
*/
-class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
+class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends Map[A, B]() with Logging {
val internalMap = new ConcurrentHashMap[A, (B, Long)]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
- if (value != null) Some(value._1) else None
+ if (value != null && updateTimeStampOnGet) {
+ internalMap.replace(key, value, (value._1, currentTime))
+ }
+ Option(value).map(_._1)
}
def iterator: Iterator[(A, B)] = {
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
index fe710c58ac..62fd6d8da5 100644
--- a/core/src/main/scala/org/apache/spark/util/Vector.scala
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.util.Random
+
class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length
@@ -124,6 +126,12 @@ object Vector {
def ones(length: Int) = Vector(length, _ => 1)
+ /**
+ * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
+ * between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided.
+ */
+ def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble())
+
class Multiplier(num: Double) {
def * (vec: Vector) = vec * num
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 8bb4ee3bfa..d98c7aa3d7 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -15,7 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
+
+import java.util.{Arrays, Comparator}
/**
* A simple open hash table optimized for the append-only use case, where keys
@@ -28,14 +30,15 @@ package org.apache.spark.util
* TODO: Cache the hash values of each key? java.util.HashMap does that.
*/
private[spark]
-class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K,
+ V)] with Serializable {
require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
require(initialCapacity >= 1, "Invalid initial capacity")
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
- private var growThreshold = LOAD_FACTOR * capacity
+ private var growThreshold = (LOAD_FACTOR * capacity).toInt
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -45,10 +48,15 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var haveNullValue = false
private var nullValue: V = null.asInstanceOf[V]
+ // Triggered by destructiveSortedIterator; the underlying data array may no longer be used
+ private var destroyed = false
+ private val destructionMessage = "Map state is invalid from destructive sorting!"
+
private val LOAD_FACTOR = 0.7
/** Get the value for a given key */
def apply(key: K): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
return nullValue
@@ -72,6 +80,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Set the value for a key */
def update(key: K, value: V): Unit = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -106,6 +115,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
* for key, if any, or null otherwise. Returns the newly updated value.
*/
def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ assert(!destroyed, destructionMessage)
val k = key.asInstanceOf[AnyRef]
if (k.eq(null)) {
if (!haveNullValue) {
@@ -139,35 +149,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Iterator method from Iterable */
- override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
- var pos = -1
-
- /** Get the next value we should return from next(), or null if we're finished iterating */
- def nextValue(): (K, V) = {
- if (pos == -1) { // Treat position -1 as looking at the null value
- if (haveNullValue) {
- return (null.asInstanceOf[K], nullValue)
+ override def iterator: Iterator[(K, V)] = {
+ assert(!destroyed, destructionMessage)
+ new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
}
- pos += 1
- }
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
}
- pos += 1
+ null
}
- null
- }
- override def hasNext: Boolean = nextValue() != null
+ override def hasNext: Boolean = nextValue() != null
- override def next(): (K, V) = {
- val value = nextValue()
- if (value == null) {
- throw new NoSuchElementException("End of iterator")
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
}
- pos += 1
- value
}
}
@@ -190,7 +203,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
}
/** Double the table's size and re-hash everything */
- private def growTable() {
+ protected def growTable() {
val newCapacity = capacity * 2
if (newCapacity >= (1 << 30)) {
// We can't make the table this big because we want an array of 2x
@@ -227,11 +240,58 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
data = newData
capacity = newCapacity
mask = newMask
- growThreshold = LOAD_FACTOR * newCapacity
+ growThreshold = (LOAD_FACTOR * newCapacity).toInt
}
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
if (highBit == n) n else highBit << 1
}
+
+ /**
+ * Return an iterator of the map in sorted order. This provides a way to sort the map without
+ * using additional memory, at the expense of destroying the validity of the map.
+ */
+ def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
+ destroyed = true
+ // Pack KV pairs into the front of the underlying array
+ var keyIndex, newIndex = 0
+ while (keyIndex < capacity) {
+ if (data(2 * keyIndex) != null) {
+ data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
+ newIndex += 1
+ }
+ keyIndex += 1
+ }
+ assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
+
+ // Sort by the given ordering
+ val rawOrdering = new Comparator[AnyRef] {
+ def compare(x: AnyRef, y: AnyRef): Int = {
+ cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
+ }
+ }
+ Arrays.sort(data, 0, newIndex, rawOrdering)
+
+ new Iterator[(K, V)] {
+ var i = 0
+ var nullValueReady = haveNullValue
+ def hasNext: Boolean = (i < newIndex || nullValueReady)
+ def next(): (K, V) = {
+ if (nullValueReady) {
+ nullValueReady = false
+ (null.asInstanceOf[K], nullValue)
+ } else {
+ val item = data(i).asInstanceOf[(K, V)]
+ i += 1
+ item
+ }
+ }
+ }
+ }
+
+ /**
+ * Return whether the next insert will cause the map to grow
+ */
+ def atGrowThreshold: Boolean = curSize == growThreshold
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
new file mode 100644
index 0000000000..e3bcd895aa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
+
+/**
+ * An append-only map that spills sorted content to disk when there is insufficient space for it
+ * to grow.
+ *
+ * This map takes two passes over the data:
+ *
+ * (1) Values are merged into combiners, which are sorted and spilled to disk as necessary
+ * (2) Combiners are read from disk and merged together
+ *
+ * The setting of the spill threshold faces the following trade-off: If the spill threshold is
+ * too high, the in-memory map may occupy more memory than is available, resulting in OOM.
+ * However, if the spill threshold is too low, we spill frequently and incur unnecessary disk
+ * writes. This may lead to a performance regression compared to the normal case of using the
+ * non-spilling AppendOnlyMap.
+ *
+ * Two parameters control the memory threshold:
+ *
+ * `spark.shuffle.memoryFraction` specifies the collective amount of memory used for storing
+ * these maps as a fraction of the executor's total memory. Since each concurrently running
+ * task maintains one map, the actual threshold for each map is this quantity divided by the
+ * number of running tasks.
+ *
+ * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of
+ * this threshold, in case map size estimation is not sufficiently accurate.
+ */
+
+private[spark] class ExternalAppendOnlyMap[K, V, C](
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ serializer: Serializer = SparkEnv.get.serializerManager.default,
+ diskBlockManager: DiskBlockManager = SparkEnv.get.blockManager.diskBlockManager)
+ extends Iterable[(K, C)] with Serializable with Logging {
+
+ import ExternalAppendOnlyMap._
+
+ private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ private val spilledMaps = new ArrayBuffer[DiskMapIterator]
+ private val sparkConf = SparkEnv.get.conf
+
+ // Collective memory threshold shared across all running tasks
+ private val maxMemoryThreshold = {
+ val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.3)
+ val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+
+ // Number of pairs in the in-memory map
+ private var numPairsInMemory = 0
+
+ // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
+ private val trackMemoryThreshold = 1000
+
+ // How many times we have spilled so far
+ private var spillCount = 0
+
+ private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+ private val syncWrites = sparkConf.getBoolean("spark.shuffle.sync", false)
+ private val comparator = new KCComparator[K, C]
+ private val ser = serializer.newInstance()
+
+ /**
+ * Insert the given key and value into the map.
+ *
+ * If the underlying map is about to grow, check if the global pool of shuffle memory has
+ * enough room for this to happen. If so, allocate the memory required to grow the map;
+ * otherwise, spill the in-memory map to disk.
+ *
+ * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
+ */
+ def insert(key: K, value: V) {
+ val update: (Boolean, C) => C = (hadVal, oldVal) => {
+ if (hadVal) mergeValue(oldVal, value) else createCombiner(value)
+ }
+ if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
+ val mapSize = currentMap.estimateSize()
+ var shouldSpill = false
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+ // Atomically check whether there is sufficient memory in the global pool for
+ // this map to grow and, if possible, allocate the required amount
+ shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
+ val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
+ val availableMemory = maxMemoryThreshold -
+ (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
+
+ // Assume map growth factor is 2x
+ shouldSpill = availableMemory < mapSize * 2
+ if (!shouldSpill) {
+ shuffleMemoryMap(threadId) = mapSize * 2
+ }
+ }
+ // Do not synchronize spills
+ if (shouldSpill) {
+ spill(mapSize)
+ }
+ }
+ currentMap.changeValue(key, update)
+ numPairsInMemory += 1
+ }
+
+ /**
+ * Sort the existing contents of the in-memory map and spill them to a temporary file on disk
+ */
+ private def spill(mapSize: Long) {
+ spillCount += 1
+ logWarning("Spilling in-memory map of %d MB to disk (%d time%s so far)"
+ .format(mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ val (blockId, file) = diskBlockManager.createTempBlock()
+ val writer =
+ new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity, syncWrites)
+ try {
+ val it = currentMap.destructiveSortedIterator(comparator)
+ while (it.hasNext) {
+ val kv = it.next()
+ writer.write(kv)
+ }
+ writer.commit()
+ } finally {
+ // Partial failures cannot be tolerated; do not revert partial writes
+ writer.close()
+ }
+ currentMap = new SizeTrackingAppendOnlyMap[K, C]
+ spilledMaps.append(new DiskMapIterator(file))
+
+ // Reset the amount of shuffle memory used by this map in the global pool
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap(Thread.currentThread().getId) = 0
+ }
+ numPairsInMemory = 0
+ }
+
+ /**
+ * Return an iterator that merges the in-memory map with the spilled maps.
+ * If no spill has occurred, simply return the in-memory map's iterator.
+ */
+ override def iterator: Iterator[(K, C)] = {
+ if (spilledMaps.isEmpty) {
+ currentMap.iterator
+ } else {
+ new ExternalIterator()
+ }
+ }
+
+ /**
+ * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
+ */
+ private class ExternalIterator extends Iterator[(K, C)] {
+
+ // A fixed-size queue that maintains a buffer for each stream we are currently merging
+ val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
+
+ // Input streams are derived both from the in-memory map and spilled maps on disk
+ // The in-memory map is sorted in place, while the spilled maps are already in sorted order
+ val sortedMap = currentMap.destructiveSortedIterator(comparator)
+ val inputStreams = Seq(sortedMap) ++ spilledMaps
+
+ inputStreams.foreach { it =>
+ val kcPairs = getMorePairs(it)
+ mergeHeap.enqueue(StreamBuffer(it, kcPairs))
+ }
+
+ /**
+ * Fetch from the given iterator until a key of different hash is retrieved. In the
+ * event of key hash collisions, this ensures no pairs are hidden from being merged.
+ * Assume the given iterator is in sorted order.
+ */
+ def getMorePairs(it: Iterator[(K, C)]): ArrayBuffer[(K, C)] = {
+ val kcPairs = new ArrayBuffer[(K, C)]
+ if (it.hasNext) {
+ var kc = it.next()
+ kcPairs += kc
+ val minHash = kc._1.hashCode()
+ while (it.hasNext && kc._1.hashCode() == minHash) {
+ kc = it.next()
+ kcPairs += kc
+ }
+ }
+ kcPairs
+ }
+
+ /**
+ * If the given buffer contains a value for the given key, merge that value into
+ * baseCombiner and remove the corresponding (K, C) pair from the buffer
+ */
+ def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
+ var i = 0
+ while (i < buffer.pairs.size) {
+ val (k, c) = buffer.pairs(i)
+ if (k == key) {
+ buffer.pairs.remove(i)
+ return mergeCombiners(baseCombiner, c)
+ }
+ i += 1
+ }
+ baseCombiner
+ }
+
+ /**
+ * Return true if there exists an input stream that still has unvisited pairs
+ */
+ override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
+
+ /**
+ * Select a key with the minimum hash, then combine all values with the same key from all input streams.
+ */
+ override def next(): (K, C) = {
+ // Select a key from the StreamBuffer that holds the lowest key hash
+ val minBuffer = mergeHeap.dequeue()
+ val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
+ if (minPairs.length == 0) {
+ // Should only happen when no other stream buffers have any pairs left
+ throw new NoSuchElementException
+ }
+ var (minKey, minCombiner) = minPairs.remove(0)
+ assert(minKey.hashCode() == minHash)
+
+ // For all other streams that may have this key (i.e. have the same minimum key hash),
+ // merge in the corresponding value (if any) from that stream
+ val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
+ while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) {
+ val newBuffer = mergeHeap.dequeue()
+ minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
+ mergedBuffers += newBuffer
+ }
+
+ // Repopulate each visited stream buffer and add it back to the merge heap
+ mergedBuffers.foreach { buffer =>
+ if (buffer.pairs.length == 0) {
+ buffer.pairs ++= getMorePairs(buffer.iterator)
+ }
+ mergeHeap.enqueue(buffer)
+ }
+
+ (minKey, minCombiner)
+ }
+
+ /**
+ * A buffer for streaming from a map iterator (in-memory or on-disk) sorted by key hash.
+ * Each buffer maintains the lowest-ordered keys in the corresponding iterator. Due to
+ * hash collisions, it is possible for multiple keys to be "tied" for being the lowest.
+ *
+ * StreamBuffers are ordered by the minimum key hash found across all of their own pairs.
+ */
+ case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
+ extends Comparable[StreamBuffer] {
+
+ def minKeyHash: Int = {
+ if (pairs.length > 0){
+ // pairs are already sorted by key hash
+ pairs(0)._1.hashCode()
+ } else {
+ Int.MaxValue
+ }
+ }
+
+ override def compareTo(other: StreamBuffer): Int = {
+ // minus sign because mutable.PriorityQueue dequeues the max, not the min
+ -minKeyHash.compareTo(other.minKeyHash)
+ }
+ }
+ }
+
+ /**
+ * An iterator that returns (K, C) pairs in sorted order from an on-disk map
+ */
+ private class DiskMapIterator(file: File) extends Iterator[(K, C)] {
+ val fileStream = new FileInputStream(file)
+ val bufferedStream = new FastBufferedInputStream(fileStream)
+ val deserializeStream = ser.deserializeStream(bufferedStream)
+ var nextItem: (K, C) = null
+ var eof = false
+
+ def readNextItem(): (K, C) = {
+ if (!eof) {
+ try {
+ return deserializeStream.readObject().asInstanceOf[(K, C)]
+ } catch {
+ case e: EOFException =>
+ eof = true
+ cleanup()
+ }
+ }
+ null
+ }
+
+ override def hasNext: Boolean = {
+ if (nextItem == null) {
+ nextItem = readNextItem()
+ }
+ nextItem != null
+ }
+
+ override def next(): (K, C) = {
+ val item = if (nextItem == null) readNextItem() else nextItem
+ if (item == null) {
+ throw new NoSuchElementException
+ }
+ nextItem = null
+ item
+ }
+
+ // TODO: Ensure this gets called even if the iterator isn't drained.
+ def cleanup() {
+ deserializeStream.close()
+ file.delete()
+ }
+ }
+}
+
+private[spark] object ExternalAppendOnlyMap {
+ private class KCComparator[K, C] extends Comparator[(K, C)] {
+ def compare(kc1: (K, C), kc2: (K, C)): Int = {
+ kc1._1.hashCode().compareTo(kc2._1.hashCode())
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
new file mode 100644
index 0000000000..204330dad4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.util.SizeEstimator
+import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap.Sample
+
+/**
+ * Append-only map that keeps track of its estimated size in bytes.
+ * We sample with a slow exponential back-off using the SizeEstimator to amortize the time,
+ * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds).
+ */
+private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] {
+
+ /**
+ * Controls the base of the exponential which governs the rate of sampling.
+ * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements.
+ */
+ private val SAMPLE_GROWTH_RATE = 1.1
+
+ /** All samples taken since last resetSamples(). Only the last two are used for extrapolation. */
+ private val samples = new ArrayBuffer[Sample]()
+
+ /** Total number of insertions and updates into the map since the last resetSamples(). */
+ private var numUpdates: Long = _
+
+ /** The value of 'numUpdates' at which we will take our next sample. */
+ private var nextSampleNum: Long = _
+
+ /** The average number of bytes per update between our last two samples. */
+ private var bytesPerUpdate: Double = _
+
+ resetSamples()
+
+ /** Called after the map grows in size, as this can be a dramatic change for small objects. */
+ def resetSamples() {
+ numUpdates = 1
+ nextSampleNum = 1
+ samples.clear()
+ takeSample()
+ }
+
+ override def update(key: K, value: V): Unit = {
+ super.update(key, value)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ }
+
+ override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val newValue = super.changeValue(key, updateFunc)
+ numUpdates += 1
+ if (nextSampleNum == numUpdates) { takeSample() }
+ newValue
+ }
+
+ /** Takes a new sample of the current map's size. */
+ def takeSample() {
+ samples += Sample(SizeEstimator.estimate(this), numUpdates)
+ // Only use the last two samples to extrapolate. If fewer than 2 samples, assume no change.
+ bytesPerUpdate = math.max(0, samples.toSeq.reverse match {
+ case latest :: previous :: tail =>
+ (latest.size - previous.size).toDouble / (latest.numUpdates - previous.numUpdates)
+ case _ =>
+ 0
+ })
+ nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong
+ }
+
+ override protected def growTable() {
+ super.growTable()
+ resetSamples()
+ }
+
+ /** Estimates the current size of the map in bytes. O(1) time. */
+ def estimateSize(): Long = {
+ assert(samples.nonEmpty)
+ val extrapolatedDelta = bytesPerUpdate * (numUpdates - samples.last.numUpdates)
+ (samples.last.size + extrapolatedDelta).toLong
+ }
+}
+
+private object SizeTrackingAppendOnlyMap {
+ case class Sample(size: Long, numUpdates: Long)
+}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..93f0c6a8e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/SizeTrackingAppendOnlyMapSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.util.SizeTrackingAppendOnlyMapSuite.LargeDummyClass
+import org.apache.spark.util.collection.{AppendOnlyMap, SizeTrackingAppendOnlyMap}
+
+class SizeTrackingAppendOnlyMapSuite extends FunSuite with BeforeAndAfterAll {
+ val NORMAL_ERROR = 0.20
+ val HIGH_ERROR = 0.30
+
+ test("fixed size insertions") {
+ testWith[Int, Long](10000, i => (i, i.toLong))
+ testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong)))
+ testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass()))
+ }
+
+ test("variable size insertions") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[Int, String](10000, i => (i, randString(0, 10)))
+ testWith[Int, String](10000, i => (i, randString(0, 100)))
+ testWith[Int, String](10000, i => (i, randString(90, 100)))
+ }
+
+ test("updates") {
+ val rand = new Random(123456789)
+ def randString(minLen: Int, maxLen: Int): String = {
+ "a" * (rand.nextInt(maxLen - minLen) + minLen)
+ }
+ testWith[String, Int](10000, i => (randString(0, 10000), i))
+ }
+
+ def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) {
+ val map = new SizeTrackingAppendOnlyMap[K, V]()
+ for (i <- 0 until numElements) {
+ val (k, v) = makeElement(i)
+ map(k) = v
+ expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR)
+ }
+ }
+
+ def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) {
+ val betterEstimatedSize = SizeEstimator.estimate(obj)
+ assert(betterEstimatedSize * (1 - error) < estimatedSize,
+ s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize")
+ assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize,
+ s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize")
+ }
+}
+
+object SizeTrackingAppendOnlyMapSuite {
+ // Speed test, for reproducibility of results.
+ // These could be highly non-deterministic in general, however.
+ // Results:
+ // AppendOnlyMap: 31 ms
+ // SizeTracker: 54 ms
+ // SizeEstimator: 1500 ms
+ def main(args: Array[String]) {
+ val numElements = 100000
+
+ val baseTimes = for (i <- 0 until 10) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ }
+ }
+
+ val sampledTimes = for (i <- 0 until 10) yield time {
+ val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ map.estimateSize()
+ }
+ }
+
+ val unsampledTimes = for (i <- 0 until 3) yield time {
+ val map = new AppendOnlyMap[Int, LargeDummyClass]()
+ for (i <- 0 until numElements) {
+ map(i) = new LargeDummyClass()
+ SizeEstimator.estimate(map)
+ }
+ }
+
+ println("Base: " + baseTimes)
+ println("SizeTracker (sampled): " + sampledTimes)
+ println("SizeEstimator (unsampled): " + unsampledTimes)
+ }
+
+ def time(f: => Unit): Long = {
+ val start = System.currentTimeMillis()
+ f
+ System.currentTimeMillis() - start
+ }
+
+ private class LargeDummyClass {
+ val arr = new Array[Int](100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
new file mode 100644
index 0000000000..7006571ef0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+/**
+ * Tests org.apache.spark.util.Vector functionality
+ */
+class VectorSuite extends FunSuite {
+
+ def verifyVector(vector: Vector, expectedLength: Int) = {
+ assert(vector.length == expectedLength)
+ assert(vector.elements.min > 0.0)
+ assert(vector.elements.max < 1.0)
+ }
+
+ test("random with default random number generator") {
+ val vector100 = Vector.random(100)
+ verifyVector(vector100, 100)
+ }
+
+ test("random with given random number generator") {
+ val vector100 = Vector.random(100, new Random(100))
+ verifyVector(vector100, 100)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
index 7177919a58..f44442f1a5 100644
--- a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import java.util.Comparator
class AppendOnlyMapSuite extends FunSuite {
test("initialization") {
@@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite {
assert(map("" + i) === "" + i)
}
}
+
+ test("destructive sort") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ map.update(null, "happy new year!")
+
+ try {
+ map.apply("1")
+ map.update("1", "2013")
+ map.changeValue("1", (hadValue, oldValue) => "2014")
+ map.iterator
+ } catch {
+ case e: IllegalStateException => fail()
+ }
+
+ val it = map.destructiveSortedIterator(new Comparator[(String, String)] {
+ def compare(kv1: (String, String), kv2: (String, String)): Int = {
+ val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue
+ val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue
+ x.compareTo(y)
+ }
+ })
+
+ // Should be sorted by key
+ assert(it.hasNext)
+ var previous = it.next()
+ assert(previous == (null, "happy new year!"))
+ previous = it.next()
+ assert(previous == ("1", "2014"))
+ while (it.hasNext) {
+ val kv = it.next()
+ assert(kv._1.toInt > previous._1.toInt)
+ previous = kv
+ }
+
+ // All subsequent calls to apply, update, changeValue and iterator should throw exception
+ intercept[AssertionError] { map.apply("1") }
+ intercept[AssertionError] { map.update("1", "2013") }
+ intercept[AssertionError] { map.changeValue("1", (hadValue, oldValue) => "2014") }
+ intercept[AssertionError] { map.iterator }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..ef957bb0e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -0,0 +1,230 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ override def beforeEach() {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.externalSorting", "true")
+ sc = new SparkContext("local", "test", conf)
+ }
+
+ val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
+ val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
+ buffer += i
+ }
+ val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
+ (buf1, buf2) => {
+ buf1 ++= buf2
+ }
+
+ test("simple insert") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ // Single insert
+ map.insert(1, 10)
+ var it = map.iterator
+ assert(it.hasNext)
+ val kv = it.next()
+ assert(kv._1 == 1 && kv._2 == ArrayBuffer[Int](10))
+ assert(!it.hasNext)
+
+ // Multiple insert
+ map.insert(2, 20)
+ map.insert(3, 30)
+ it = map.iterator
+ assert(it.hasNext)
+ assert(it.toSet == Set[(Int, ArrayBuffer[Int])](
+ (1, ArrayBuffer[Int](10)),
+ (2, ArrayBuffer[Int](20)),
+ (3, ArrayBuffer[Int](30))))
+ }
+
+ test("insert with collision") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+
+ map.insert(1, 10)
+ map.insert(2, 20)
+ map.insert(3, 30)
+ map.insert(1, 100)
+ map.insert(2, 200)
+ map.insert(1, 1000)
+ val it = map.iterator
+ assert(it.hasNext)
+ val result = it.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](10, 100, 1000)),
+ (2, Set[Int](20, 200)),
+ (3, Set[Int](30))))
+ }
+
+ test("ordering") {
+ val map1 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map1.insert(1, 10)
+ map1.insert(2, 20)
+ map1.insert(3, 30)
+
+ val map2 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map2.insert(2, 20)
+ map2.insert(3, 30)
+ map2.insert(1, 10)
+
+ val map3 = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map3.insert(3, 30)
+ map3.insert(1, 10)
+ map3.insert(2, 20)
+
+ val it1 = map1.iterator
+ val it2 = map2.iterator
+ val it3 = map3.iterator
+
+ var kv1 = it1.next()
+ var kv2 = it2.next()
+ var kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+
+ kv1 = it1.next()
+ kv2 = it2.next()
+ kv3 = it3.next()
+ assert(kv1._1 == kv2._1 && kv2._1 == kv3._1)
+ assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
+ }
+
+ test("null keys and values") {
+ val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+ mergeValue, mergeCombiners)
+ map.insert(1, 5)
+ map.insert(2, 6)
+ map.insert(3, 7)
+ assert(map.size === 3)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7))
+ ))
+
+ // Null keys
+ val nullInt = null.asInstanceOf[Int]
+ map.insert(nullInt, 8)
+ assert(map.size === 4)
+ assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+ (1, Seq[Int](5)),
+ (2, Seq[Int](6)),
+ (3, Seq[Int](7)),
+ (nullInt, Seq[Int](8))
+ ))
+
+ // Null values
+ map.insert(4, nullInt)
+ map.insert(nullInt, nullInt)
+ assert(map.size === 5)
+ val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+ assert(result == Set[(Int, Set[Int])](
+ (1, Set[Int](5)),
+ (2, Set[Int](6)),
+ (3, Set[Int](7)),
+ (4, Set[Int](nullInt)),
+ (nullInt, Set[Int](nullInt, 8))
+ ))
+ }
+
+ test("simple aggregator") {
+ // reduceByKey
+ val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
+ val result1 = rdd.reduceByKey(_+_).collect()
+ assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5)))
+
+ // groupByKey
+ val result2 = rdd.groupByKey().collect()
+ assert(result2.toSet == Set[(Int, Seq[Int])]
+ ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1))))
+ }
+
+ test("simple cogroup") {
+ val rdd1 = sc.parallelize(1 to 4).map(i => (i, i))
+ val rdd2 = sc.parallelize(1 to 4).map(i => (i%2, i))
+ val result = rdd1.cogroup(rdd2).collect()
+
+ result.foreach { case (i, (seq1, seq2)) =>
+ i match {
+ case 0 => assert(seq1.toSet == Set[Int]() && seq2.toSet == Set[Int](2, 4))
+ case 1 => assert(seq1.toSet == Set[Int](1) && seq2.toSet == Set[Int](1, 3))
+ case 2 => assert(seq1.toSet == Set[Int](2) && seq2.toSet == Set[Int]())
+ case 3 => assert(seq1.toSet == Set[Int](3) && seq2.toSet == Set[Int]())
+ case 4 => assert(seq1.toSet == Set[Int](4) && seq2.toSet == Set[Int]())
+ }
+ }
+ }
+
+ test("spilling") {
+ // TODO: Figure out correct memory parameters to actually induce spilling
+ // System.setProperty("spark.shuffle.buffer.mb", "1")
+ // System.setProperty("spark.shuffle.buffer.fraction", "0.05")
+
+ // reduceByKey - should spill exactly 6 times
+ val rddA = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max(_, _)).collect()
+ assert(resultA.length == 5000)
+ resultA.foreach { case(k, v) =>
+ k match {
+ case 0 => assert(v == 1)
+ case 2500 => assert(v == 5001)
+ case 4999 => assert(v == 9999)
+ case _ =>
+ }
+ }
+
+ // groupByKey - should spill exactly 11 times
+ val rddB = sc.parallelize(0 until 10000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey().collect()
+ assert(resultB.length == 2500)
+ resultB.foreach { case(i, seq) =>
+ i match {
+ case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
+ case 1250 => assert(seq.toSet == Set[Int](5000, 5001, 5002, 5003))
+ case 2499 => assert(seq.toSet == Set[Int](9996, 9997, 9998, 9999))
+ case _ =>
+ }
+ }
+
+ // cogroup - should spill exactly 7 times
+ val rddC1 = sc.parallelize(0 until 1000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 1000).map(i => (i%100, i))
+ val resultC = rddC1.cogroup(rddC2).collect()
+ assert(resultC.length == 1000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 100, 200, 300, 400, 500, 600, 700, 800, 900))
+ case 500 =>
+ assert(seq1.toSet == Set[Int](500))
+ assert(seq2.toSet == Set[Int]())
+ case 999 =>
+ assert(seq1.toSet == Set[Int](999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+ }
+
+ // TODO: Test memory allocation for multiple concurrently running tasks
+}
diff --git a/docs/configuration.md b/docs/configuration.md
index 6717757781..ad75e06fc7 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -104,14 +104,25 @@ Apart from these, the following properties are also available, and may be useful
</tr>
<tr>
<td>spark.storage.memoryFraction</td>
- <td>0.66</td>
+ <td>0.6</td>
<td>
Fraction of Java heap to use for Spark's memory cache. This should not be larger than the "old"
- generation of objects in the JVM, which by default is given 2/3 of the heap, but you can increase
+ generation of objects in the JVM, which by default is given 0.6 of the heap, but you can increase
it if you configure your own old generation size.
</td>
</tr>
<tr>
+ <td>spark.shuffle.memoryFraction</td>
+ <td>0.3</td>
+ <td>
+ Fraction of Java heap to use for aggregation and cogroups during shuffles, if
+ <code>spark.shuffle.externalSorting</code> is enabled. At any given time, the collective size of
+ all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
+ begin to spill to disk. If spills are often, consider increasing this value at the expense of
+ <code>spark.storage.memoryFraction</code>.
+ </td>
+</tr>
+<tr>
<td>spark.mesos.coarse</td>
<td>false</td>
<td>
@@ -371,12 +382,20 @@ Apart from these, the following properties are also available, and may be useful
<tr>
<td>spark.shuffle.consolidateFiles</td>
- <td>false</td>
+ <td>true</td>
<td>
If set to "true", consolidates intermediate files created during a shuffle. Creating fewer files can improve filesystem performance for shuffles with large numbers of reduce tasks. It is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option might degrade performance on machines with many (>8) cores due to filesystem limitations.
</td>
</tr>
<tr>
+ <td>spark.shuffle.externalSorting</td>
+ <td>true</td>
+ <td>
+ If set to "true", limits the amount of memory used during reduces by spilling data out to disk. This spilling
+ threshold is specified by <code>spark.shuffle.memoryFraction</code>.
+ </td>
+</tr>
+<tr>
<td>spark.speculation</td>
<td>false</td>
<td>
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index dc187b3efe..c4236f8312 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -99,8 +99,9 @@ $ MASTER=local[4] ./bin/pyspark
## IPython
-It is also possible to launch PySpark in [IPython](http://ipython.org), the enhanced Python interpreter.
-To do this, set the `IPYTHON` variable to `1` when running `bin/pyspark`:
+It is also possible to launch PySpark in [IPython](http://ipython.org), the
+enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To
+use IPython, set the `IPYTHON` variable to `1` when running `bin/pyspark`:
{% highlight bash %}
$ IPYTHON=1 ./bin/pyspark
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index b206270107..3bd62646ba 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -101,7 +101,19 @@ With this mode, your application is actually run on the remote machine where the
With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
-In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh
+Configuration in yarn-client mode:
+
+In order to tune worker core/number/memory etc. You need to export environment variables or add them to the spark configuration file (./conf/spark_env.sh). The following are the list of options.
+
+* `SPARK_YARN_APP_JAR`, Path to your application's JAR file (required)
+* `SPARK_WORKER_INSTANCES`, Number of workers to start (Default: 2)
+* `SPARK_WORKER_CORES`, Number of cores for the workers (Default: 1).
+* `SPARK_WORKER_MEMORY`, Memory per Worker (e.g. 1000M, 2G) (Default: 1G)
+* `SPARK_MASTER_MEMORY`, Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)
+* `SPARK_YARN_APP_NAME`, The name of your application (Default: Spark)
+* `SPARK_YARN_QUEUE`, The hadoop queue to use for allocation requests (Default: 'default')
+* `SPARK_YARN_DIST_FILES`, Comma separated list of files to be distributed with the job.
+* `SPARK_YARN_DIST_ARCHIVES`, Comma separated list of archives to be distributed with the job.
For example:
@@ -114,7 +126,6 @@ For example:
SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
MASTER=yarn-client ./bin/spark-shell
-You can also send extra files to yarn cluster for worker to use by exporting SPARK_YARN_DIST_FILES=file1,file2... etc.
# Building Spark for Hadoop/YARN 2.2.x
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index d82a1e1490..e7cb5ab3ff 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -185,7 +185,11 @@ def get_spark_ami(opts):
"hi1.4xlarge": "hvm",
"m3.xlarge": "hvm",
"m3.2xlarge": "hvm",
- "cr1.8xlarge": "hvm"
+ "cr1.8xlarge": "hvm",
+ "i2.xlarge": "hvm",
+ "i2.2xlarge": "hvm",
+ "i2.4xlarge": "hvm",
+ "i2.8xlarge": "hvm"
}
if opts.instance_type in instance_types:
instance_type = instance_types[opts.instance_type]
@@ -478,7 +482,11 @@ def get_num_disks(instance_type):
"cr1.8xlarge": 2,
"hi1.4xlarge": 2,
"m3.xlarge": 0,
- "m3.2xlarge": 0
+ "m3.2xlarge": 0,
+ "i2.xlarge": 1,
+ "i2.2xlarge": 2,
+ "i2.4xlarge": 4,
+ "i2.8xlarge": 8
}
if instance_type in disks_by_instance:
return disks_by_instance[instance_type]
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java
index b11cfa667e..7b5a243e26 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaFlumeEventCount.java
@@ -47,6 +47,8 @@ public final class JavaFlumeEventCount {
System.exit(1);
}
+ StreamingExamples.setStreamingLogLevels();
+
String master = args[0];
String host = args[1];
int port = Integer.parseInt(args[2]);
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
index 16b8a948e6..04f62ee204 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -59,6 +59,8 @@ public final class JavaKafkaWordCount {
System.exit(1);
}
+ StreamingExamples.setStreamingLogLevels();
+
// Create the context with a 1 second batch size
JavaStreamingContext jssc = new JavaStreamingContext(args[0], "KafkaWordCount",
new Duration(2000), System.getenv("SPARK_HOME"),
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
index 1e2efd359c..349d826ab5 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java
@@ -38,7 +38,7 @@ import java.util.regex.Pattern;
* To run this on your local machine, you need to first run a Netcat server
* `$ nc -lk 9999`
* and then run the example
- * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999`
+ * `$ ./run org.apache.spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999`
*/
public final class JavaNetworkWordCount {
private static final Pattern SPACE = Pattern.compile(" ");
@@ -48,18 +48,20 @@ public final class JavaNetworkWordCount {
public static void main(String[] args) {
if (args.length < 3) {
- System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
+ System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1");
System.exit(1);
}
+ StreamingExamples.setStreamingLogLevels();
+
// Create the context with a 1 second batch size
- JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaNetworkWordCount",
new Duration(1000), System.getenv("SPARK_HOME"),
JavaStreamingContext.jarOfClass(JavaNetworkWordCount.class));
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2]));
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
@@ -82,6 +84,5 @@ public final class JavaNetworkWordCount {
wordCounts.print();
ssc.start();
-
}
}
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java
index e05551ab83..7ef9c6c8f4 100644
--- a/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaQueueStream.java
@@ -41,6 +41,8 @@ public final class JavaQueueStream {
System.exit(1);
}
+ StreamingExamples.setStreamingLogLevels();
+
// Create the context
JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000),
System.getenv("SPARK_HOME"), JavaStreamingContext.jarOfClass(JavaQueueStream.class));
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index 4e0058cd70..57e1b1f806 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -18,17 +18,13 @@
package org.apache.spark.streaming.examples
import scala.collection.mutable.LinkedList
-import scala.util.Random
import scala.reflect.ClassTag
+import scala.util.Random
-import akka.actor.Actor
-import akka.actor.ActorRef
-import akka.actor.Props
-import akka.actor.actorRef2Scala
+import akka.actor.{Actor, ActorRef, Props, actorRef2Scala}
import org.apache.spark.SparkConf
-import org.apache.spark.streaming.Seconds
-import org.apache.spark.streaming.StreamingContext
+import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions
import org.apache.spark.streaming.receivers.Receiver
import org.apache.spark.util.AkkaUtils
@@ -147,6 +143,8 @@ object ActorWordCount {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val Seq(master, host, port) = args.toSeq
// Create the context and set the batch size
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala
index ae3709b3d9..a59be7899d 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/FlumeEventCount.scala
@@ -17,10 +17,10 @@
package org.apache.spark.streaming.examples
-import org.apache.spark.util.IntParam
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.flume._
+import org.apache.spark.util.IntParam
/**
* Produces a count of events received from Flume.
@@ -44,6 +44,8 @@ object FlumeEventCount {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val Array(master, host, IntParam(port)) = args
val batchInterval = Milliseconds(2000)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala
index ea6ea67419..704b315ef8 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala
@@ -20,7 +20,6 @@ package org.apache.spark.streaming.examples
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
-
/**
* Counts words in new text files created in the given directory
* Usage: HdfsWordCount <master> <directory>
@@ -38,6 +37,8 @@ object HdfsWordCount {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
// Create the context
val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2),
System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
index 31a94bd224..4a3d81c09a 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
@@ -23,8 +23,8 @@ import kafka.producer._
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.streaming.util.RawTextHelper._
import org.apache.spark.streaming.kafka._
+import org.apache.spark.streaming.util.RawTextHelper._
/**
* Consumes messages from one or more topics in Kafka and does wordcount.
@@ -40,12 +40,13 @@ import org.apache.spark.streaming.kafka._
*/
object KafkaWordCount {
def main(args: Array[String]) {
-
if (args.length < 5) {
System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>")
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val Array(master, zkQuorum, group, topics, numThreads) = args
val ssc = new StreamingContext(master, "KafkaWordCount", Seconds(2),
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
index 325290b66f..78b49fdcf1 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
@@ -17,12 +17,8 @@
package org.apache.spark.streaming.examples
-import org.eclipse.paho.client.mqttv3.MqttClient
-import org.eclipse.paho.client.mqttv3.MqttClientPersistence
+import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic}
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-import org.eclipse.paho.client.mqttv3.MqttException
-import org.eclipse.paho.client.mqttv3.MqttMessage
-import org.eclipse.paho.client.mqttv3.MqttTopic
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
@@ -43,6 +39,8 @@ object MQTTPublisher {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val Seq(brokerUrl, topic) = args.toSeq
try {
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
index 6a32c75373..25f7013307 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala
@@ -21,7 +21,8 @@ import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
/**
- * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
* Usage: NetworkWordCount <master> <hostname> <port>
* <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
@@ -39,12 +40,14 @@ object NetworkWordCount {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
// Create the context with a 1 second batch size
val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1),
System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
// Create a NetworkInputDStream on target ip:port and count the
- // words in input stream of \n delimited test (eg. generated by 'nc')
+ // words in input stream of \n delimited text (eg. generated by 'nc')
val lines = ssc.socketTextStream(args(1), args(2).toInt)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala
index 9d640e716b..4d4968ba6a 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala
@@ -17,12 +17,12 @@
package org.apache.spark.streaming.examples
+import scala.collection.mutable.SynchronizedQueue
+
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
-import scala.collection.mutable.SynchronizedQueue
-
object QueueStream {
def main(args: Array[String]) {
@@ -30,7 +30,9 @@ object QueueStream {
System.err.println("Usage: QueueStream <master>")
System.exit(1)
}
-
+
+ StreamingExamples.setStreamingLogLevels()
+
// Create the context
val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1),
System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
index c0706d0724..3d08d86567 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RawNetworkGrep.scala
@@ -17,11 +17,10 @@
package org.apache.spark.streaming.examples
-import org.apache.spark.util.IntParam
import org.apache.spark.storage.StorageLevel
-
import org.apache.spark.streaming._
import org.apache.spark.streaming.util.RawTextHelper
+import org.apache.spark.util.IntParam
/**
* Receives text from multiple rawNetworkStreams and counts how many '\n' delimited
@@ -45,6 +44,8 @@ object RawNetworkGrep {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args
// Create the context
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
new file mode 100644
index 0000000000..d51e6e9418
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/RecoverableNetworkWordCount.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples
+
+import org.apache.spark.streaming.{Time, Seconds, StreamingContext}
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.util.IntParam
+import java.io.File
+import org.apache.spark.rdd.RDD
+import com.google.common.io.Files
+import java.nio.charset.Charset
+
+/**
+ * Counts words in text encoded with UTF8 received from the network every second.
+ *
+ * Usage: NetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ * <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
+ * <output-file> file to which the word counts will be appended
+ *
+ * In local mode, <master> should be 'local[n]' with n > 1
+ * <checkpoint-directory> and <output-file> must be absolute paths
+ *
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ *
+ * `$ nc -lk 9999`
+ *
+ * and run the example as
+ *
+ * `$ ./run-example org.apache.spark.streaming.examples.RecoverableNetworkWordCount \
+ * local[2] localhost 9999 ~/checkpoint/ ~/out`
+ *
+ * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create
+ * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if
+ * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from
+ * the checkpoint data.
+ *
+ * To run this example in a local standalone cluster with automatic driver recovery,
+ *
+ * `$ ./spark-class org.apache.spark.deploy.Client -s launch <cluster-url> <path-to-examples-jar> \
+ * org.apache.spark.streaming.examples.RecoverableNetworkWordCount <cluster-url> \
+ * localhost 9999 ~/checkpoint ~/out`
+ *
+ * <path-to-examples-jar> would typically be <spark-dir>/examples/target/scala-XX/spark-examples....jar
+ *
+ * Refer to the online documentation for more details.
+ */
+
+object RecoverableNetworkWordCount {
+
+ def createContext(master: String, ip: String, port: Int, outputPath: String) = {
+
+ // If you do not see this printed, that means the StreamingContext has been loaded
+ // from the new checkpoint
+ println("Creating new context")
+ val outputFile = new File(outputPath)
+ if (outputFile.exists()) outputFile.delete()
+
+ // Create the context with a 1 second batch size
+ val ssc = new StreamingContext(master, "RecoverableNetworkWordCount", Seconds(1),
+ System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass))
+
+ // Create a NetworkInputDStream on target ip:port and count the
+ // words in input stream of \n delimited text (eg. generated by 'nc')
+ val lines = ssc.socketTextStream(ip, port)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.foreach((rdd: RDD[(String, Int)], time: Time) => {
+ val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]")
+ println(counts)
+ println("Appending to " + outputFile.getAbsolutePath)
+ Files.append(counts + "\n", outputFile, Charset.defaultCharset())
+ })
+ ssc
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("You arguments were " + args.mkString("[", ", ", "]"))
+ System.err.println(
+ """
+ |Usage: RecoverableNetworkWordCount <master> <hostname> <port> <checkpoint-directory> <output-file>
+ | <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ | <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ | <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data
+ | <output-file> file to which the word counts will be appended
+ |
+ |In local mode, <master> should be 'local[n]' with n > 1
+ |Both <checkpoint-directory> and <output-file> must be absolute paths
+ """.stripMargin
+ )
+ System.exit(1)
+ }
+ val Array(master, ip, IntParam(port), checkpointDirectory, outputPath) = args
+ val ssc = StreamingContext.getOrCreate(checkpointDirectory,
+ () => {
+ createContext(master, ip, port, outputPath)
+ })
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala
index 002db57d59..1183eba846 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala
@@ -39,6 +39,8 @@ object StatefulNetworkWordCount {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
val currentCount = values.foldLeft(0)(_ + _)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala
new file mode 100644
index 0000000000..d41d84a980
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StreamingExamples.scala
@@ -0,0 +1,21 @@
+package org.apache.spark.streaming.examples
+
+import org.apache.spark.Logging
+
+import org.apache.log4j.{Level, Logger}
+
+/** Utility functions for Spark Streaming examples. */
+object StreamingExamples extends Logging {
+
+ /** Set reasonable logging levels for streaming if the user has not configured log4j. */
+ def setStreamingLogLevels() {
+ val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4jInitialized) {
+ // We first log something to initialize Spark's default logging, then we override the
+ // logging level.
+ logInfo("Setting log level to [WARN] for streaming example." +
+ " To override add a custom log4j.properties to the classpath.")
+ Logger.getRootLogger.setLevel(Level.WARN)
+ }
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
index 3ccdc908e2..80b5a98b14 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdCMS.scala
@@ -17,12 +17,12 @@
package org.apache.spark.streaming.examples
-import org.apache.spark.streaming.{Seconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
import com.twitter.algebird._
-import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.SparkContext._
+import org.apache.spark.SparkContext._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.twitter._
/**
@@ -51,6 +51,8 @@ object TwitterAlgebirdCMS {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
// CMS parameters
val DELTA = 1E-3
val EPS = 0.01
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
index c7e83e76b0..cb2f2c51a0 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterAlgebirdHLL.scala
@@ -17,10 +17,11 @@
package org.apache.spark.streaming.examples
-import org.apache.spark.streaming.{Seconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import com.twitter.algebird.HyperLogLog._
import com.twitter.algebird.HyperLogLogMonoid
+import com.twitter.algebird.HyperLogLog._
+
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter._
/**
@@ -44,6 +45,8 @@ object TwitterAlgebirdHLL {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
/** Bit size parameter for HyperLogLog, trades off accuracy vs size */
val BIT_SIZE = 12
val (master, filters) = (args.head, args.tail)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
index e2b0418d55..16c10feaba 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/TwitterPopularTags.scala
@@ -36,6 +36,8 @@ object TwitterPopularTags {
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
+
val (master, filters) = (args.head, args.tail)
val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2),
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
index 03902ec353..12d2a1084f 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
@@ -76,6 +76,7 @@ object ZeroMQWordCount {
"In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
val Seq(master, url, topic) = args.toSeq
// Create the context and set the batch size
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
index 807af199f4..da6b67bcce 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewStream.scala
@@ -17,9 +17,10 @@
package org.apache.spark.streaming.examples.clickstream
+import org.apache.spark.SparkContext._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.SparkContext._
+import org.apache.spark.streaming.examples.StreamingExamples
/** Analyses a streaming dataset of web page views. This class demonstrates several types of
* operators available in Spark streaming.
@@ -36,6 +37,7 @@ object PageViewStream {
" errorRatePerZipCode, activeUserCount, popularUsersSeen")
System.exit(1)
}
+ StreamingExamples.setStreamingLogLevels()
val metric = args(0)
val host = args(1)
val port = args(2).toInt
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index ca0115f90e..1249ef4c3d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -24,10 +24,10 @@ import java.util.concurrent.RejectedExecutionException
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.{SparkConf, Logging}
+import org.apache.spark.{SparkException, SparkConf, Logging}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.MetadataCleaner
-import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.streaming.scheduler.JobGenerator
private[streaming]
@@ -44,6 +44,10 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf)
val sparkConf = ssc.conf
+ // These should be unset when a checkpoint is deserialized,
+ // otherwise the SparkContext won't initialize correctly.
+ sparkConf.remove("spark.hostPort").remove("spark.driver.host").remove("spark.driver.port")
+
def validate() {
assert(master != null, "Checkpoint.master is null")
assert(framework != null, "Checkpoint.framework is null")
@@ -53,59 +57,119 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
}
}
+private[streaming]
+object Checkpoint extends Logging {
+ val PREFIX = "checkpoint-"
+ val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r
+
+ /** Get the checkpoint file for the given checkpoint time */
+ def checkpointFile(checkpointDir: String, checkpointTime: Time) = {
+ new Path(checkpointDir, PREFIX + checkpointTime.milliseconds)
+ }
+
+ /** Get the checkpoint backup file for the given checkpoint time */
+ def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = {
+ new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk")
+ }
+
+ /** Get checkpoint files present in the give directory, ordered by oldest-first */
+ def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = {
+ def sortFunc(path1: Path, path2: Path): Boolean = {
+ val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+ val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+ (time1 < time2) || (time1 == time2 && bk1)
+ }
+
+ val path = new Path(checkpointDir)
+ if (fs.exists(path)) {
+ val statuses = fs.listStatus(path)
+ if (statuses != null) {
+ val paths = statuses.map(_.getPath)
+ val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty)
+ filtered.sortWith(sortFunc)
+ } else {
+ logWarning("Listing " + path + " returned null")
+ Seq.empty
+ }
+ } else {
+ logInfo("Checkpoint directory " + path + " does not exist")
+ Seq.empty
+ }
+ }
+}
+
/**
* Convenience class to handle the writing of graph checkpoint to file
*/
private[streaming]
-class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Configuration)
- extends Logging
-{
- val file = new Path(checkpointDir, "graph")
+class CheckpointWriter(
+ jobGenerator: JobGenerator,
+ conf: SparkConf,
+ checkpointDir: String,
+ hadoopConf: Configuration
+ ) extends Logging {
val MAX_ATTEMPTS = 3
val executor = Executors.newFixedThreadPool(1)
val compressionCodec = CompressionCodec.createCodec(conf)
- // The file to which we actually write - and then "move" to file
- val writeFile = new Path(file.getParent, file.getName + ".next")
- // The file to which existing checkpoint is backed up (i.e. "moved")
- val bakFile = new Path(file.getParent, file.getName + ".bk")
-
private var stopped = false
private var fs_ : FileSystem = _
- // Removed code which validates whether there is only one CheckpointWriter per path 'file' since
- // I did not notice any errors - reintroduce it ?
class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable {
def run() {
var attempts = 0
val startTime = System.currentTimeMillis()
+ val tempFile = new Path(checkpointDir, "temp")
+ val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime)
+ val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime)
+
while (attempts < MAX_ATTEMPTS && !stopped) {
attempts += 1
try {
- logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
- // This is inherently thread unsafe, so alleviating it by writing to '.new' and
- // then moving it to the final file
- val fos = fs.create(writeFile)
+ logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'")
+
+ // Write checkpoint to temp file
+ fs.delete(tempFile, true) // just in case it exists
+ val fos = fs.create(tempFile)
fos.write(bytes)
fos.close()
- if (fs.exists(file) && fs.rename(file, bakFile)) {
- logDebug("Moved existing checkpoint file to " + bakFile)
+
+ // If the checkpoint file exists, back it up
+ // If the backup exists as well, just delete it, otherwise rename will fail
+ if (fs.exists(checkpointFile)) {
+ fs.delete(backupFile, true) // just in case it exists
+ if (!fs.rename(checkpointFile, backupFile)) {
+ logWarning("Could not rename " + checkpointFile + " to " + backupFile)
+ }
+ }
+
+ // Rename temp file to the final checkpoint file
+ if (!fs.rename(tempFile, checkpointFile)) {
+ logWarning("Could not rename " + tempFile + " to " + checkpointFile)
+ }
+
+ // Delete old checkpoint files
+ val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
+ if (allCheckpointFiles.size > 4) {
+ allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
+ logInfo("Deleting " + file)
+ fs.delete(file, true)
+ })
}
- // paranoia
- fs.delete(file, false)
- fs.rename(writeFile, file)
+ // All done, print success
val finishTime = System.currentTimeMillis()
- logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
- "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
+ logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile +
+ "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms")
+ jobGenerator.onCheckpointCompletion(checkpointTime)
return
} catch {
case ioe: IOException =>
- logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
+ logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe)
reset()
}
}
- logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ logWarning("Could not write checkpoint for time " + checkpointTime + " to file " + checkpointFile + "'")
}
}
@@ -118,6 +182,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
bos.close()
try {
executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray))
+ logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
@@ -140,7 +205,7 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
}
private def fs = synchronized {
- if (fs_ == null) fs_ = file.getFileSystem(hadoopConf)
+ if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf)
fs_
}
@@ -153,43 +218,46 @@ class CheckpointWriter(conf: SparkConf, checkpointDir: String, hadoopConf: Confi
private[streaming]
object CheckpointReader extends Logging {
- def read(conf: SparkConf, path: String): Checkpoint = {
- val fs = new Path(path).getFileSystem(new Configuration())
- val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"),
- new Path(path), new Path(path + ".bk"))
+ def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = {
+ val checkpointPath = new Path(checkpointDir)
+ def fs = checkpointPath.getFileSystem(hadoopConf)
+
+ // Try to find the checkpoint files
+ val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse
+ if (checkpointFiles.isEmpty) {
+ return None
+ }
+ // Try to read the checkpoint files in the order
+ logInfo("Checkpoint files found: " + checkpointFiles.mkString(","))
val compressionCodec = CompressionCodec.createCodec(conf)
-
- attempts.foreach(file => {
- if (fs.exists(file)) {
- logInfo("Attempting to load checkpoint from file '" + file + "'")
- try {
- val fis = fs.open(file)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val zis = compressionCodec.compressedInputStream(fis)
- val ois = new ObjectInputStreamWithLoader(zis,
- Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- logInfo("Checkpoint successfully loaded from file '" + file + "'")
- logInfo("Checkpoint was generated at time " + cp.checkpointTime)
- return cp
- } catch {
- case e: Exception =>
- logError("Error loading checkpoint from file '" + file + "'", e)
- }
- } else {
- logWarning("Could not read checkpoint from file '" + file + "' as it does not exist")
+ checkpointFiles.foreach(file => {
+ logInfo("Attempting to load checkpoint from file " + file)
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which maybe the wrong class loader. Hence, a inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well know Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val zis = compressionCodec.compressedInputStream(fis)
+ val ois = new ObjectInputStreamWithLoader(zis,
+ Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ logInfo("Checkpoint successfully loaded from file " + file)
+ logInfo("Checkpoint was generated at time " + cp.checkpointTime)
+ return Some(cp)
+ } catch {
+ case e: Exception =>
+ logWarning("Error reading checkpoint from file " + file, e)
}
-
})
- throw new Exception("Could not read checkpoint from path '" + path + "'")
+
+ // If none of checkpoint files could be read, then throw exception
+ throw new SparkException("Failed to read checkpoint from directory " + checkpointPath)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 00671ba520..b98f4a5101 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -329,13 +329,12 @@ abstract class DStream[T: ClassTag] (
* implementation clears the old generated RDDs. Subclasses of DStream may override
* this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def clearOldMetadata(time: Time) {
- var numForgotten = 0
+ protected[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
generatedRDDs --= oldRDDs.keys
- logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " +
+ logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +
(time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
- dependencies.foreach(_.clearOldMetadata(time))
+ dependencies.foreach(_.clearMetadata(time))
}
/* Adds metadata to the Stream while it is running.
@@ -356,12 +355,18 @@ abstract class DStream[T: ClassTag] (
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
logInfo("Updating checkpoint data for time " + currentTime)
- checkpointData.update()
+ checkpointData.update(currentTime)
dependencies.foreach(_.updateCheckpointData(currentTime))
- checkpointData.cleanup()
logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
+ protected[streaming] def clearCheckpointData(time: Time) {
+ logInfo("Clearing checkpoint data")
+ checkpointData.cleanup(time)
+ dependencies.foreach(_.clearCheckpointData(time))
+ logInfo("Cleared checkpoint data")
+ }
+
/**
* Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method
* that should not be called directly. This is a default implementation that recreates RDDs
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
index 3fd5d52403..671f7bbce7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
@@ -17,77 +17,86 @@
package org.apache.spark.streaming
+import scala.collection.mutable.{HashMap, HashSet}
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.conf.Configuration
-import collection.mutable.HashMap
import org.apache.spark.Logging
-import scala.collection.mutable.HashMap
-import scala.reflect.ClassTag
-
+import java.io.{ObjectInputStream, IOException}
private[streaming]
class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
- @transient private var fileSystem : FileSystem = null
- @transient private var lastCheckpointFiles: HashMap[Time, String] = null
+ // Mapping of the batch time to the checkpointed RDD file of that time
+ @transient private var timeToCheckpointFile = new HashMap[Time, String]
+ // Mapping of the batch time to the time of the oldest checkpointed RDD
+ // in that batch's checkpoint data
+ @transient private var timeToOldestCheckpointFileTime = new HashMap[Time, Time]
- protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]]
+ @transient private var fileSystem : FileSystem = null
+ protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]]
/**
* Updates the checkpoint data of the DStream. This gets called every time
* the graph checkpoint is initiated. Default implementation records the
* checkpoint files to which the generate RDDs of the DStream has been saved.
*/
- def update() {
+ def update(time: Time) {
// Get the checkpointed RDDs from the generated RDDs
- val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
.map(x => (x._1, x._2.getCheckpointFile.get))
-
- // Make a copy of the existing checkpoint data (checkpointed RDDs)
- lastCheckpointFiles = checkpointFiles.clone()
-
- // If the new checkpoint data has checkpoints then replace existing with the new one
- if (newCheckpointFiles.size > 0) {
- checkpointFiles.clear()
- checkpointFiles ++= newCheckpointFiles
- }
-
- // TODO: remove this, this is just for debugging
- newCheckpointFiles.foreach {
- case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+ logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n"))
+
+ // Add the checkpoint files to the data to be serialized
+ if (!checkpointFiles.isEmpty) {
+ currentCheckpointFiles.clear()
+ currentCheckpointFiles ++= checkpointFiles
+ // Add the current checkpoint files to the map of all checkpoint files
+ // This will be used to delete old checkpoint files
+ timeToCheckpointFile ++= currentCheckpointFiles
+ // Remember the time of the oldest checkpoint RDD in current state
+ timeToOldestCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering)
}
}
/**
- * Cleanup old checkpoint data. This gets called every time the graph
- * checkpoint is initiated, but after `update` is called. Default
- * implementation, cleans up old checkpoint files.
+ * Cleanup old checkpoint data. This gets called after a checkpoint of `time` has been
+ * written to the checkpoint directory.
*/
- def cleanup() {
- // If there is at least on checkpoint file in the current checkpoint files,
- // then delete the old checkpoint files.
- if (checkpointFiles.size > 0 && lastCheckpointFiles != null) {
- (lastCheckpointFiles -- checkpointFiles.keySet).foreach {
- case (time, file) => {
- try {
- val path = new Path(file)
- if (fileSystem == null) {
- fileSystem = path.getFileSystem(new Configuration())
+ def cleanup(time: Time) {
+ // Get the time of the oldest checkpointed RDD that was written as part of the
+ // checkpoint of `time`
+ timeToOldestCheckpointFileTime.remove(time) match {
+ case Some(lastCheckpointFileTime) =>
+ // Find all the checkpointed RDDs (i.e. files) that are older than `lastCheckpointFileTime`
+ // This is because checkpointed RDDs older than this are not going to be needed
+ // even after master fails, as the checkpoint data of `time` does not refer to those files
+ val filesToDelete = timeToCheckpointFile.filter(_._1 < lastCheckpointFileTime)
+ logDebug("Files to delete:\n" + filesToDelete.mkString(","))
+ filesToDelete.foreach {
+ case (time, file) =>
+ try {
+ val path = new Path(file)
+ if (fileSystem == null) {
+ fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration)
+ }
+ fileSystem.delete(path, true)
+ timeToCheckpointFile -= time
+ logInfo("Deleted checkpoint file '" + file + "' for time " + time)
+ } catch {
+ case e: Exception =>
+ logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
+ fileSystem = null
}
- fileSystem.delete(path, true)
- logInfo("Deleted checkpoint file '" + file + "' for time " + time)
- } catch {
- case e: Exception =>
- logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
- }
}
- }
+ case None =>
+ logInfo("Nothing to delete")
}
}
@@ -98,7 +107,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
*/
def restore() {
// Create RDDs from the checkpoint data
- checkpointFiles.foreach {
+ currentCheckpointFiles.foreach {
case(time, file) => {
logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
@@ -107,6 +116,13 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
}
override def toString() = {
- "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]"
+ "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]"
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ timeToOldestCheckpointFileTime = new HashMap[Time, Time]
+ timeToCheckpointFile = new HashMap[Time, String]
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index a09b891956..eee9591ffc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -104,36 +104,44 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def getOutputStreams() = this.synchronized { outputStreams.toArray }
def generateJobs(time: Time): Seq[Job] = {
- this.synchronized {
- logInfo("Generating jobs for time " + time)
- val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time))
- logInfo("Generated " + jobs.length + " jobs for time " + time)
- jobs
+ logDebug("Generating jobs for time " + time)
+ val jobs = this.synchronized {
+ outputStreams.flatMap(outputStream => outputStream.generateJob(time))
}
+ logDebug("Generated " + jobs.length + " jobs for time " + time)
+ jobs
}
- def clearOldMetadata(time: Time) {
+ def clearMetadata(time: Time) {
+ logDebug("Clearing metadata for time " + time)
this.synchronized {
- logInfo("Clearing old metadata for time " + time)
- outputStreams.foreach(_.clearOldMetadata(time))
- logInfo("Cleared old metadata for time " + time)
+ outputStreams.foreach(_.clearMetadata(time))
}
+ logDebug("Cleared old metadata for time " + time)
}
def updateCheckpointData(time: Time) {
+ logInfo("Updating checkpoint data for time " + time)
this.synchronized {
- logInfo("Updating checkpoint data for time " + time)
outputStreams.foreach(_.updateCheckpointData(time))
- logInfo("Updated checkpoint data for time " + time)
}
+ logInfo("Updated checkpoint data for time " + time)
+ }
+
+ def clearCheckpointData(time: Time) {
+ logInfo("Clearing checkpoint data for time " + time)
+ this.synchronized {
+ outputStreams.foreach(_.clearCheckpointData(time))
+ }
+ logInfo("Cleared checkpoint data for time " + time)
}
def restoreCheckpointData() {
+ logInfo("Restoring checkpoint data")
this.synchronized {
- logInfo("Restoring checkpoint data")
outputStreams.foreach(_.restoreCheckpointData())
- logInfo("Restored checkpoint data")
}
+ logInfo("Restored checkpoint data")
}
def validate() {
@@ -146,8 +154,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
+ logDebug("DStreamGraph.writeObject used")
this.synchronized {
- logDebug("DStreamGraph.writeObject used")
checkpointInProgress = true
oos.defaultWriteObject()
checkpointInProgress = false
@@ -156,8 +164,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream) {
+ logDebug("DStreamGraph.readObject used")
this.synchronized {
- logDebug("DStreamGraph.readObject used")
checkpointInProgress = true
ois.defaultReadObject()
checkpointInProgress = false
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 693cb7fc30..dd34f6f4f2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -39,6 +39,7 @@ import org.apache.spark.util.MetadataCleaner
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receivers._
import org.apache.spark.streaming.scheduler._
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -88,10 +89,12 @@ class StreamingContext private (
/**
* Re-create a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ * @param hadoopConf Optional, configuration object if necessary for reading from
+ * HDFS compatible filesystems
*/
- def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null)
+ def this(path: String, hadoopConf: Configuration = new Configuration) =
+ this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null)
if (sc_ == null && cp_ == null) {
throw new Exception("Spark Streaming cannot be initialized with " +
@@ -171,8 +174,9 @@ class StreamingContext private (
/**
* Set the context to periodically checkpoint the DStream operations for master
- * fault-tolerance. The graph will be checkpointed every batch interval.
- * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * fault-tolerance.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored.
+ * Note that this must be a fault-tolerant file system like HDFS for
*/
def checkpoint(directory: String) {
if (directory != null) {
@@ -461,26 +465,64 @@ class StreamingContext private (
}
}
+/**
+ * StreamingContext object contains a number of utility functions related to the
+ * StreamingContext class.
+ */
-object StreamingContext {
+object StreamingContext extends Logging {
implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
/**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the StreamingContext
+ * will be created by called the provided `creatingFunc`.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param creatingFunc Function to create a new StreamingContext
+ * @param hadoopConf Optional Hadoop configuration if necessary for reading from the
+ * file system
+ * @param createOnError Optional, whether to create a new StreamingContext if there is an
+ * error in reading checkpoint data. By default, an exception will be
+ * thrown on error.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ creatingFunc: () => StreamingContext,
+ hadoopConf: Configuration = new Configuration(),
+ createOnError: Boolean = false
+ ): StreamingContext = {
+ val checkpointOption = try {
+ CheckpointReader.read(checkpointPath, new SparkConf(), hadoopConf)
+ } catch {
+ case e: Exception =>
+ if (createOnError) {
+ None
+ } else {
+ throw e
+ }
+ }
+ checkpointOption.map(new StreamingContext(null, _, null)).getOrElse(creatingFunc())
+ }
+
+ /**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
- * their JARs to SparkContext.
+ * their JARs to StreamingContext.
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls)
+
protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
- val sc = new SparkContext(conf)
- if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
- MetadataCleaner.setDelaySeconds(sc.conf, 3600)
+ if (MetadataCleaner.getDelaySeconds(conf) < 0) {
+ MetadataCleaner.setDelaySeconds(conf, 3600)
}
+ val sc = new SparkContext(conf)
sc
}
@@ -489,14 +531,17 @@ object StreamingContext {
appName: String,
sparkHome: String,
jars: Seq[String],
- environment: Map[String, String]): SparkContext =
- {
- val sc = new SparkContext(master, appName, sparkHome, jars, environment)
+ environment: Map[String, String]
+ ): SparkContext = {
+
+ val conf = SparkContext.updatedConf(
+ new SparkConf(), master, appName, sparkHome, jars, environment)
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) {
- MetadataCleaner.setDelaySeconds(sc.conf, 3600)
+ if (MetadataCleaner.getDelaySeconds(conf) < 0) {
+ MetadataCleaner.setDelaySeconds(conf, 3600)
}
+ val sc = new SparkContext(master, appName, sparkHome, jars, environment)
sc
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 7068f32517..523173d45a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.hadoop.conf.Configuration
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -128,10 +129,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Re-creates a StreamingContext from a checkpoint file.
- * @param path Path either to the directory that was specified as the checkpoint directory, or
- * to the checkpoint file 'graph' or 'graph.bk'.
+ * @param path Path to the directory that was specified as the checkpoint directory
*/
- def this(path: String) = this (new StreamingContext(path))
+ def this(path: String) = this(new StreamingContext(path))
+
+ /**
+ * Re-creates a StreamingContext from a checkpoint file.
+ * @param path Path to the directory that was specified as the checkpoint directory
+ *
+ */
+ def this(path: String, hadoopConf: Configuration) = this(new StreamingContext(path, hadoopConf))
/** The underlying SparkContext */
val sc: JavaSparkContext = new JavaSparkContext(ssc.sc)
@@ -471,20 +478,97 @@ class JavaStreamingContext(val ssc: StreamingContext) {
}
/**
- * Starts the execution of the streams.
+ * Start the execution of the streams.
*/
def start() = ssc.start()
/**
- * Sstops the execution of the streams.
+ * Stop the execution of the streams.
*/
def stop() = ssc.stop()
}
+/**
+ * JavaStreamingContext object contains a number of utility functions.
+ */
object JavaStreamingContext {
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ })
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf)
+ new JavaStreamingContext(ssc)
+ }
+
+ /**
+ * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ * If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ * recreated from the checkpoint data. If the data does not exist, then the provided factory
+ * will be used to create a JavaStreamingContext.
+ *
+ * @param checkpointPath Checkpoint directory used in an earlier StreamingContext program
+ * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext
+ * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible
+ * file system
+ * @param createOnError Whether to create a new JavaStreamingContext if there is an
+ * error in reading checkpoint data.
+ */
+ def getOrCreate(
+ checkpointPath: String,
+ hadoopConf: Configuration,
+ factory: JavaStreamingContextFactory,
+ createOnError: Boolean
+ ): JavaStreamingContext = {
+ val ssc = StreamingContext.getOrCreate(checkpointPath, () => {
+ factory.create.ssc
+ }, hadoopConf, createOnError)
+ new JavaStreamingContext(ssc)
+ }
+
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
- * their JARs to SparkContext.
+ * their JARs to StreamingContext.
*/
def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls).toArray
}
+
+/**
+ * Factory interface for creating a new JavaStreamingContext
+ */
+trait JavaStreamingContextFactory {
+ def create(): JavaStreamingContext
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index fb9eda8996..1f0f31c4b1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -23,10 +23,10 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
import org.apache.spark.streaming.{DStreamCheckpointData, StreamingContext, Time}
+import org.apache.spark.util.TimeStampedHashMap
private[streaming]
@@ -46,6 +46,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
@transient private var path_ : Path = null
@transient private var fs_ : FileSystem = null
@transient private[streaming] var files = new HashMap[Time, Array[String]]
+ @transient private var fileModTimes = new TimeStampedHashMap[String, Long](true)
+ @transient private var lastNewFileFindingTime = 0L
override def start() {
if (newFilesOnly) {
@@ -88,14 +90,16 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
}
/** Clear the old time-to-files mappings along with old RDDs */
- protected[streaming] override def clearOldMetadata(time: Time) {
- super.clearOldMetadata(time)
+ protected[streaming] override def clearMetadata(time: Time) {
+ super.clearMetadata(time)
val oldFiles = files.filter(_._1 <= (time - rememberDuration))
files --= oldFiles.keys
logInfo("Cleared " + oldFiles.size + " old files that were older than " +
(time - rememberDuration) + ": " + oldFiles.keys.mkString(", "))
logDebug("Cleared files are:\n" +
oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n"))
+ // Delete file mod times that weren't accessed in the last round of getting new files
+ fileModTimes.clearOldValues(lastNewFileFindingTime - 1)
}
/**
@@ -104,8 +108,19 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
*/
private def findNewFiles(currentTime: Long): (Seq[String], Long, Seq[String]) = {
logDebug("Trying to get new files for time " + currentTime)
+ lastNewFileFindingTime = System.currentTimeMillis
val filter = new CustomPathFilter(currentTime)
- val newFiles = fs.listStatus(path, filter).map(_.getPath.toString)
+ val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString)
+ val timeTaken = System.currentTimeMillis - lastNewFileFindingTime
+ logInfo("Finding new files took " + timeTaken + " ms")
+ logDebug("# cached file times = " + fileModTimes.size)
+ if (timeTaken > slideDuration.milliseconds) {
+ logWarning(
+ "Time taken to find new files exceeds the batch size. " +
+ "Consider increasing the batch size or reduceing the number of " +
+ "files in the monitored directory."
+ )
+ }
(newFiles, filter.latestModTime, filter.latestModTimeFiles.toSeq)
}
@@ -122,16 +137,21 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
new UnionRDD(context.sparkContext, fileRDDs)
}
- private def path: Path = {
+ private def directoryPath: Path = {
if (path_ == null) path_ = new Path(directory)
path_
}
private def fs: FileSystem = {
- if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+ if (fs_ == null) fs_ = directoryPath.getFileSystem(new Configuration())
fs_
}
+ private def getFileModTime(path: Path) = {
+ // Get file mod time from cache or fetch it from the file system
+ fileModTimes.getOrElseUpdate(path.toString, fs.getFileStatus(path).getModificationTime())
+ }
+
private def reset() {
fs_ = null
}
@@ -142,6 +162,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
ois.defaultReadObject()
generatedRDDs = new HashMap[Time, RDD[(K,V)]] ()
files = new HashMap[Time, Array[String]]
+ fileModTimes = new TimeStampedHashMap[String, Long](true)
}
/**
@@ -153,15 +174,15 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]]
- override def update() {
+ override def update(time: Time) {
hadoopFiles.clear()
hadoopFiles ++= files
}
- override def cleanup() { }
+ override def cleanup(time: Time) { }
override def restore() {
- hadoopFiles.foreach {
+ hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach {
case (t, f) => {
// Restore the metadata in both files and generatedRDDs
logInfo("Restoring files for time " + t + " - " +
@@ -187,14 +208,13 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
// Latest file mod time seen in this round of fetching files and its corresponding files
var latestModTime = 0L
val latestModTimeFiles = new HashSet[String]()
-
def accept(path: Path): Boolean = {
try {
if (!filter(path)) { // Reject file if it does not satisfy filter
logDebug("Rejected by filter " + path)
return false
}
- val modTime = fs.getFileStatus(path).getModificationTime()
+ val modTime = getFileModTime(path)
logDebug("Mod time for " + path + " is " + modTime)
if (modTime < prevModTime) {
logDebug("Mod time less than last mod time")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 3c624e8199..2fa6853ae0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -26,8 +26,9 @@ import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
-private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent
private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent
/**
* This class generates jobs from DStreams as well as drives checkpointing and cleaning
@@ -53,7 +54,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+ new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
@@ -77,15 +78,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
* On batch completion, clear old metadata and checkpoint computation.
*/
private[scheduler] def onBatchCompletion(time: Time) {
- eventProcessorActor ! ClearOldMetadata(time)
+ eventProcessorActor ! ClearMetadata(time)
+ }
+
+ private[streaming] def onCheckpointCompletion(time: Time) {
+ eventProcessorActor ! ClearCheckpointData(time)
}
/** Processes all events */
private def processEvent(event: JobGeneratorEvent) {
event match {
case GenerateJobs(time) => generateJobs(time)
- case ClearOldMetadata(time) => clearOldMetadata(time)
+ case ClearMetadata(time) => clearMetadata(time)
case DoCheckpoint(time) => doCheckpoint(time)
+ case ClearCheckpointData(time) => clearCheckpointData(time)
}
}
@@ -115,14 +121,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val checkpointTime = ssc.initialCheckpoint.checkpointTime
val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
val downTimes = checkpointTime.until(restartTime, batchDuration)
- logInfo("Batches during down time: " + downTimes.mkString(", "))
+ logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", "))
// Batches that were unprocessed before failure
- val pendingTimes = ssc.initialCheckpoint.pendingTimes
- logInfo("Batches pending processing: " + pendingTimes.mkString(", "))
+ val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering)
+ logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", "))
// Reschedule jobs for these times
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
- logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
+ logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
jobScheduler.runJobs(time, graph.generateJobs(time))
)
@@ -141,11 +147,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
/** Clear DStream metadata for the given `time`. */
- private def clearOldMetadata(time: Time) {
- ssc.graph.clearOldMetadata(time)
+ private def clearMetadata(time: Time) {
+ ssc.graph.clearMetadata(time)
eventProcessorActor ! DoCheckpoint(time)
}
+ /** Clear DStream checkpoint data for the given `time`. */
+ private def clearCheckpointData(time: Time) {
+ ssc.graph.clearCheckpointData(time)
+ }
+
/** Perform checkpoint for the give `time`. */
private def doCheckpoint(time: Time) = synchronized {
if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 1559f7a9f7..162b19d7f0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -42,6 +42,7 @@ object MasterFailureTest extends Logging {
@volatile var killed = false
@volatile var killCount = 0
+ @volatile var setupCalled = false
def main(args: Array[String]) {
if (args.size < 2) {
@@ -131,8 +132,26 @@ object MasterFailureTest extends Logging {
// Just making sure that the expected output does not have duplicates
assert(expectedOutput.distinct.toSet == expectedOutput.toSet)
+ // Reset all state
+ reset()
+
+ // Create the directories for this test
+ val uuid = UUID.randomUUID().toString
+ val rootDir = new Path(directory, uuid)
+ val fs = rootDir.getFileSystem(new Configuration())
+ val checkpointDir = new Path(rootDir, "checkpoint")
+ val testDir = new Path(rootDir, "test")
+ fs.mkdirs(checkpointDir)
+ fs.mkdirs(testDir)
+
// Setup the stream computation with the given operation
- val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation)
+ val ssc = StreamingContext.getOrCreate(checkpointDir.toString, () => {
+ setupStreams(batchDuration, operation, checkpointDir, testDir)
+ })
+
+ // Check if setupStream was called to create StreamingContext
+ // (and not created from checkpoint file)
+ assert(setupCalled, "Setup was not called in the first call to StreamingContext.getOrCreate")
// Start generating files in the a different thread
val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
@@ -144,9 +163,7 @@ object MasterFailureTest extends Logging {
val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
- // Delete directories
fileGeneratingThread.join()
- val fs = checkpointDir.getFileSystem(new Configuration())
fs.delete(checkpointDir, true)
fs.delete(testDir, true)
logInfo("Finished test after " + killCount + " failures")
@@ -159,32 +176,24 @@ object MasterFailureTest extends Logging {
* files should be written for testing.
*/
private def setupStreams[T: ClassTag](
- directory: String,
batchDuration: Duration,
- operation: DStream[String] => DStream[T]
- ): (StreamingContext, Path, Path) = {
- // Reset all state
- reset()
-
- // Create the directories for this test
- val uuid = UUID.randomUUID().toString
- val rootDir = new Path(directory, uuid)
- val fs = rootDir.getFileSystem(new Configuration())
- val checkpointDir = new Path(rootDir, "checkpoint")
- val testDir = new Path(rootDir, "test")
- fs.mkdirs(checkpointDir)
- fs.mkdirs(testDir)
+ operation: DStream[String] => DStream[T],
+ checkpointDir: Path,
+ testDir: Path
+ ): StreamingContext = {
+ // Mark that setup was called
+ setupCalled = true
// Setup the streaming computation with the given operation
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
- var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
+ val ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration, null, Nil, Map())
ssc.checkpoint(checkpointDir.toString)
val inputStream = ssc.textFileStream(testDir.toString)
val operatedStream = operation(inputStream)
val outputStream = new TestOutputStream(operatedStream)
ssc.registerOutputStream(outputStream)
- (ssc, checkpointDir, testDir)
+ ssc
}
@@ -204,7 +213,7 @@ object MasterFailureTest extends Logging {
var isTimedOut = false
val mergedOutput = new ArrayBuffer[T]()
val checkpointDir = ssc.checkpointDir
- var batchDuration = ssc.graph.batchDuration
+ val batchDuration = ssc.graph.batchDuration
while(!isLastOutputGenerated && !isTimedOut) {
// Get the output buffer
@@ -261,7 +270,10 @@ object MasterFailureTest extends Logging {
)
Thread.sleep(sleepTime)
// Recreate the streaming context from checkpoint
- ssc = new StreamingContext(checkpointDir)
+ ssc = StreamingContext.getOrCreate(checkpointDir, () => {
+ throw new Exception("Trying to create new context when it " +
+ "should be reading from checkpoint file")
+ })
}
}
mergedOutput
@@ -297,6 +309,7 @@ object MasterFailureTest extends Logging {
private def reset() {
killed = false
killCount = 0
+ setupCalled = false
}
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 0d2145da9a..8b7d7709bf 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -28,6 +28,7 @@ import java.util.*;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
+import com.google.common.collect.Sets;
import org.apache.spark.SparkConf;
import org.apache.spark.HashPartitioner;
@@ -441,13 +442,13 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
new Tuple2<String, String>("new york", "islanders")));
- List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
- Arrays.asList(
+ List<HashSet<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Tuple2<String, String>>("california",
new Tuple2<String, String>("dodgers", "giants")),
new Tuple2<String, Tuple2<String, String>>("new york",
- new Tuple2<String, String>("yankees", "mets"))),
- Arrays.asList(
+ new Tuple2<String, String>("yankees", "mets"))),
+ Sets.newHashSet(
new Tuple2<String, Tuple2<String, String>>("california",
new Tuple2<String, String>("sharks", "ducks")),
new Tuple2<String, Tuple2<String, String>>("new york",
@@ -482,8 +483,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
JavaTestUtils.attachTestOutputStream(joined);
List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<HashSet<Tuple2<String, Tuple2<String, String>>>> unorderedResult = Lists.newArrayList();
+ for (List<Tuple2<String, Tuple2<String, String>>> res: result) {
+ unorderedResult.add(Sets.newHashSet(res));
+ }
- Assert.assertEquals(expected, result);
+ Assert.assertEquals(expected, unorderedResult);
}
@@ -1196,15 +1201,15 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
Arrays.asList("hello", "moon"),
Arrays.asList("hello"));
- List<List<Tuple2<String, Long>>> expected = Arrays.asList(
- Arrays.asList(
+ List<HashSet<Tuple2<String, Long>>> expected = Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 1L),
new Tuple2<String, Long>("world", 1L)),
- Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 2L),
new Tuple2<String, Long>("world", 1L),
new Tuple2<String, Long>("moon", 1L)),
- Arrays.asList(
+ Sets.newHashSet(
new Tuple2<String, Long>("hello", 2L),
new Tuple2<String, Long>("moon", 1L)));
@@ -1214,8 +1219,12 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
stream.countByValueAndWindow(new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(counted);
List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+ List<HashSet<Tuple2<String, Long>>> unorderedResult = Lists.newArrayList();
+ for (List<Tuple2<String, Long>> res: result) {
+ unorderedResult.add(Sets.newHashSet(res));
+ }
- Assert.assertEquals(expected, result);
+ Assert.assertEquals(expected, unorderedResult);
}
@Test
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index ee6b433d1f..5ccef7f461 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -375,7 +375,7 @@ class BasicOperationsSuite extends TestSuiteBase {
}
test("slice") {
- val conf2 = new SparkConf()
+ val conf2 = conf.clone()
.setMaster("local[2]")
.setAppName("BasicOperationsSuite")
.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 8dc80ac2ed..6499de98c9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -84,9 +84,9 @@ class CheckpointSuite extends TestSuiteBase {
ssc.start()
advanceTimeWithRealDelay(ssc, firstNumBatches)
logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData)
- assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+ assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before first failure")
- stateStream.checkpointData.checkpointFiles.foreach {
+ stateStream.checkpointData.currentCheckpointFiles.foreach {
case (time, file) => {
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before first failure does not exist")
@@ -95,7 +95,7 @@ class CheckpointSuite extends TestSuiteBase {
// Run till a further time such that previous checkpoint files in the stream would be deleted
// and check whether the earlier checkpoint files are deleted
- val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2))
+ val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2))
advanceTimeWithRealDelay(ssc, secondNumBatches)
checkpointFiles.foreach(file =>
assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
@@ -114,9 +114,9 @@ class CheckpointSuite extends TestSuiteBase {
// is present in the checkpoint data or not
ssc.start()
advanceTimeWithRealDelay(ssc, 1)
- assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+ assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
"No checkpointed RDDs in state stream before second failure")
- stateStream.checkpointData.checkpointFiles.foreach {
+ stateStream.checkpointData.currentCheckpointFiles.foreach {
case (time, file) => {
assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
" for state stream before seconds failure does not exist")
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 2bb11e54c5..2e46d750c4 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -127,14 +127,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// local dirs, so lets check both. We assume one of the 2 is set.
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
- .getOrElse(Option(System.getenv("LOCAL_DIRS"))
- .getOrElse(""))
-
- if (localDirs.isEmpty()) {
- throw new Exception("Yarn Local dirs can't be empty")
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
}
- localDirs
- }
+ }
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index ddfec1a4ac..62b20b8fba 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -76,6 +76,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
def run() {
+ // Setup the directories so things go to yarn approved directories rather
+ // then user specified and /tmp.
+ System.setProperty("spark.local.dir", getLocalDirs())
+
appAttemptId = getApplicationAttemptId()
resourceManager = registerWithResourceManager()
val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
@@ -103,10 +107,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
// ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- // must be <= timeoutInterval/ 2.
- // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
- val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
reporterThread = launchReporterThread(interval)
// Wait for the reporter thread to Finish.
@@ -119,6 +125,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
System.exit(0)
}
+ /** Get the Yarn approved local directories. */
+ private def getLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x have different Environment variable names for the
+ // local dirs, so lets check both. We assume one of the 2 is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
+ }
+ }
+
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 4b1b5da048..22e55e0c60 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -22,6 +22,8 @@ import org.apache.spark.{SparkException, Logging, SparkContext}
import org.apache.spark.deploy.yarn.{Client, ClientArguments}
import org.apache.spark.scheduler.TaskSchedulerImpl
+import scala.collection.mutable.ArrayBuffer
+
private[spark] class YarnClientSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext)
@@ -31,45 +33,47 @@ private[spark] class YarnClientSchedulerBackend(
var client: Client = null
var appId: ApplicationId = null
+ private[spark] def addArg(optionName: String, optionalParam: String, arrayBuf: ArrayBuffer[String]) {
+ Option(System.getenv(optionalParam)) foreach {
+ optParam => {
+ arrayBuf += (optionName, optParam)
+ }
+ }
+ }
+
override def start() {
super.start()
- val defalutWorkerCores = "2"
- val defalutWorkerMemory = "512m"
- val defaultWorkerNumber = "1"
-
val userJar = System.getenv("SPARK_YARN_APP_JAR")
- val distFiles = System.getenv("SPARK_YARN_DIST_FILES")
- var workerCores = System.getenv("SPARK_WORKER_CORES")
- var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
- var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
-
if (userJar == null)
throw new SparkException("env SPARK_YARN_APP_JAR is not set")
- if (workerCores == null)
- workerCores = defalutWorkerCores
- if (workerMemory == null)
- workerMemory = defalutWorkerMemory
- if (workerNumber == null)
- workerNumber = defaultWorkerNumber
-
val driverHost = conf.get("spark.driver.host")
val driverPort = conf.get("spark.driver.port")
val hostport = driverHost + ":" + driverPort
- val argsArray = Array[String](
+ val argsArrayBuf = new ArrayBuffer[String]()
+ argsArrayBuf += (
"--class", "notused",
"--jar", userJar,
"--args", hostport,
- "--worker-memory", workerMemory,
- "--worker-cores", workerCores,
- "--num-workers", workerNumber,
- "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher",
- "--files", distFiles
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
)
- val args = new ClientArguments(argsArray, conf)
+ // process any optional arguments, use the defaults already defined in ClientArguments
+ // if things aren't specified
+ Map("--master-memory" -> "SPARK_MASTER_MEMORY",
+ "--num-workers" -> "SPARK_WORKER_INSTANCES",
+ "--worker-memory" -> "SPARK_WORKER_MEMORY",
+ "--worker-cores" -> "SPARK_WORKER_CORES",
+ "--queue" -> "SPARK_YARN_QUEUE",
+ "--name" -> "SPARK_YARN_APP_NAME",
+ "--files" -> "SPARK_YARN_DIST_FILES",
+ "--archives" -> "SPARK_YARN_DIST_ARCHIVES")
+ .foreach { case (optName, optParam) => addArg(optName, optParam, argsArrayBuf) }
+
+ logDebug("ClientArguments called with: " + argsArrayBuf)
+ val args = new ClientArguments(argsArrayBuf.toArray, conf)
client = new Client(args, conf)
appId = client.runApp()
waitForApp()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 69ae14ce83..4b777d5fa7 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -116,14 +116,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// local dirs, so lets check both. We assume one of the 2 is set.
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
- .getOrElse(Option(System.getenv("LOCAL_DIRS"))
- .getOrElse(""))
-
- if (localDirs.isEmpty()) {
- throw new Exception("Yarn Local dirs can't be empty")
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
}
- localDirs
- }
+ }
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index be323d7783..952e963389 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -99,6 +99,7 @@ class Client(args: ClientArguments, conf: Configuration, sparkConf: SparkConf)
appContext.setApplicationName(args.appName)
appContext.setQueue(args.amQueue)
appContext.setAMContainerSpec(amContainer)
+ appContext.setApplicationType("SPARK")
// Memory for the ApplicationMaster.
val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 49248a8516..78353224fa 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -78,6 +78,10 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
def run() {
+ // Setup the directories so things go to yarn approved directories rather
+ // then user specified and /tmp.
+ System.setProperty("spark.local.dir", getLocalDirs())
+
amClient = AMRMClient.createAMRMClient()
amClient.init(yarnConf)
amClient.start()
@@ -94,10 +98,12 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
// ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- // must be <= timeoutInterval/ 2.
- // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
- val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
reporterThread = launchReporterThread(interval)
// Wait for the reporter thread to Finish.
@@ -110,6 +116,20 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
System.exit(0)
}
+ /** Get the Yarn approved local directories. */
+ private def getLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x have different Environment variable names for the
+ // local dirs, so lets check both. We assume one of the 2 is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .orElse(Option(System.getenv("LOCAL_DIRS")))
+
+ localDirs match {
+ case None => throw new Exception("Yarn Local dirs can't be empty")
+ case Some(l) => l
+ }
+ }
+
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())