aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephen Haberman <stephen@exigencecorp.com>2013-02-25 23:48:52 -0600
committerStephen Haberman <stephen@exigencecorp.com>2013-02-25 23:48:52 -0600
commita4adeb255c66bbbb8eb7f4abcfd2b4c63906be31 (patch)
treeb0071ea76237b2f1d882d75f1d24dcd3ecad6c17
parent921be765339ac6a1b1a12672d73620855984eade (diff)
parentd6e6abece306008c50410807669596d73d6d6738 (diff)
downloadspark-a4adeb255c66bbbb8eb7f4abcfd2b4c63906be31.tar.gz
spark-a4adeb255c66bbbb8eb7f4abcfd2b4c63906be31.tar.bz2
spark-a4adeb255c66bbbb8eb7f4abcfd2b4c63906be31.zip
Merge branch 'master' into nomocks
Conflicts: core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
-rw-r--r--.gitignore2
-rw-r--r--bagel/src/main/scala/spark/bagel/Bagel.scala42
-rw-r--r--bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala6
-rw-r--r--bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala2
-rw-r--r--bagel/src/test/scala/bagel/BagelSuite.scala35
-rw-r--r--core/src/main/scala/spark/CacheManager.scala4
-rw-r--r--core/src/main/scala/spark/DoubleRDDFunctions.scala4
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala87
-rw-r--r--core/src/main/scala/spark/Partition.scala (renamed from core/src/main/scala/spark/Split.scala)2
-rw-r--r--core/src/main/scala/spark/Partitioner.scala32
-rw-r--r--core/src/main/scala/spark/RDD.scala130
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala12
-rw-r--r--core/src/main/scala/spark/SparkContext.scala32
-rw-r--r--core/src/main/scala/spark/Utils.scala21
-rw-r--r--core/src/main/scala/spark/api/java/JavaDoubleRDD.scala30
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala93
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDD.scala27
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala27
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala22
-rw-r--r--core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java20
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala2
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala10
-rw-r--r--core/src/main/scala/spark/deploy/ApplicationDescription.scala (renamed from core/src/main/scala/spark/deploy/JobDescription.scala)4
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala29
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala18
-rw-r--r--core/src/main/scala/spark/deploy/client/Client.scala24
-rw-r--r--core/src/main/scala/spark/deploy/client/ClientListener.scala2
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala6
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationInfo.scala (renamed from core/src/main/scala/spark/deploy/master/JobInfo.scala)10
-rw-r--r--core/src/main/scala/spark/deploy/master/ApplicationState.scala11
-rw-r--r--core/src/main/scala/spark/deploy/master/ExecutorInfo.scala4
-rw-r--r--core/src/main/scala/spark/deploy/master/JobState.scala9
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala208
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala22
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala6
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala26
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala29
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala2
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala4
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala17
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala5
-rw-r--r--core/src/main/scala/spark/network/Connection.scala24
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala35
-rw-r--r--core/src/main/scala/spark/partial/ApproximateActionListener.scala2
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala24
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala37
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala26
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala49
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala33
-rw-r--r--core/src/main/scala/spark/rdd/FilteredRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/FlatMappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/GlommedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/HadoopRDD.scala40
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsRDD.scala8
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala (renamed from core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala)12
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala37
-rw-r--r--core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala (renamed from core/src/main/scala/spark/ParallelCollection.scala)31
-rw-r--r--core/src/main/scala/spark/rdd/PartitionPruningRDD.scala28
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala6
-rw-r--r--core/src/main/scala/spark/rdd/SampledRDD.scala22
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/SubtractedRDD.scala108
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala34
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala36
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala14
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/Stage.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala26
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala17
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala43
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala10
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala4
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala4
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala6
-rw-r--r--core/src/main/scala/spark/storage/StorageUtils.scala2
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala4
-rw-r--r--core/src/main/scala/spark/util/Vector.scala10
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_details.scala.html40
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_row.scala.html20
-rw-r--r--core/src/main/twirl/spark/deploy/master/app_table.scala.html (renamed from core/src/main/twirl/spark/deploy/master/job_table.scala.html)8
-rw-r--r--core/src/main/twirl/spark/deploy/master/executor_row.scala.html6
-rw-r--r--core/src/main/twirl/spark/deploy/master/index.scala.html22
-rw-r--r--core/src/main/twirl/spark/deploy/master/job_details.scala.html40
-rw-r--r--core/src/main/twirl/spark/deploy/master/job_row.scala.html20
-rw-r--r--core/src/main/twirl/spark/deploy/worker/executor_row.scala.html10
-rw-r--r--core/src/main/twirl/spark/deploy/worker/index.scala.html6
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala126
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala21
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java24
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala8
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala21
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala70
-rw-r--r--core/src/test/scala/spark/SortingSuite.scala10
-rw-r--r--core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala (renamed from core/src/test/scala/spark/ParallelCollectionSplitSuite.scala)40
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala21
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala10
-rw-r--r--docs/_config.yml1
-rw-r--r--docs/configuration.md16
-rw-r--r--docs/contributing-to-spark.md2
-rw-r--r--docs/python-programming-guide.md2
-rw-r--r--docs/scala-programming-guide.md2
-rw-r--r--docs/spark-standalone.md8
-rw-r--r--docs/streaming-custom-receivers.md101
-rw-r--r--docs/streaming-programming-guide.md315
-rw-r--r--docs/tuning.md10
-rw-r--r--ec2/deploy.generic/root/spark-ec2/ec2-variables.sh11
-rwxr-xr-xec2/spark_ec2.py81
-rw-r--r--examples/pom.xml7
-rw-r--r--examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java (renamed from examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java)0
-rw-r--r--examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java (renamed from examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java)4
-rw-r--r--examples/src/main/java/spark/streaming/examples/JavaQueueStream.java (renamed from examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java)0
-rw-r--r--examples/src/main/scala/spark/examples/LogQuery.scala66
-rw-r--r--examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala157
-rw-r--r--examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala28
-rw-r--r--examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala4
-rw-r--r--examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala2
-rw-r--r--examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala93
-rw-r--r--examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala71
-rw-r--r--examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala (renamed from examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala)33
-rw-r--r--examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala73
-rw-r--r--examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala13
-rw-r--r--pom.xml6
-rw-r--r--project/SparkBuild.scala12
-rwxr-xr-xpyspark7
-rw-r--r--python/pyspark/join.py20
-rw-r--r--python/pyspark/rdd.py58
-rwxr-xr-xrun34
-rw-r--r--run2.cmd26
-rw-r--r--sbt/sbt.cmd2
-rw-r--r--streaming/pom.xml11
-rw-r--r--streaming/src/main/scala/spark/streaming/Checkpoint.scala72
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala221
-rw-r--r--streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala93
-rw-r--r--streaming/src/main/scala/spark/streaming/DStreamGraph.scala58
-rw-r--r--streaming/src/main/scala/spark/streaming/Duration.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/Interval.scala1
-rw-r--r--streaming/src/main/scala/spark/streaming/JobManager.scala44
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala23
-rw-r--r--streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala156
-rw-r--r--streaming/src/main/scala/spark/streaming/Scheduler.scala102
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala181
-rw-r--r--streaming/src/main/scala/spark/streaming/Time.scala13
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala36
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala156
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala131
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala252
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala140
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala4
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala36
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala105
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala17
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala13
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala1
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala28
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala2
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala12
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala (renamed from examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala)13
-rw-r--r--streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala153
-rw-r--r--streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala33
-rw-r--r--streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala392
-rw-r--r--streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala30
-rw-r--r--streaming/src/test/java/spark/streaming/JavaAPISuite.java426
-rw-r--r--streaming/src/test/java/spark/streaming/JavaTestUtils.scala6
-rw-r--r--streaming/src/test/resources/log4j.properties1
-rw-r--r--streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala86
-rw-r--r--streaming/src/test/scala/spark/streaming/CheckpointSuite.scala193
-rw-r--r--streaming/src/test/scala/spark/streaming/FailureSuite.scala189
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala205
-rw-r--r--streaming/src/test/scala/spark/streaming/TestSuiteBase.scala22
-rw-r--r--streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala62
176 files changed, 5166 insertions, 2212 deletions
diff --git a/.gitignore b/.gitignore
index 88d7b56181..155e785b01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,5 @@ log/
spark-tests.log
streaming-tests.log
dependency-reduced-pom.xml
+.ensime
+.ensime_lucene
diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala
index 996ca2a877..094e57dacb 100644
--- a/bagel/src/main/scala/spark/bagel/Bagel.scala
+++ b/bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -6,19 +6,19 @@ import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
object Bagel extends Logging {
- def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
- C : Manifest, A : Manifest](
+ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+ C: Manifest, A: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
aggregator: Option[Aggregator[V, A]],
partitioner: Partitioner,
- numSplits: Int
+ numPartitions: Int
)(
compute: (V, Option[C], Option[A], Int) => (V, Array[M])
): RDD[(K, V)] = {
- val splits = if (numSplits != 0) numSplits else sc.defaultParallelism
+ val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism
var superstep = 0
var verts = vertices
@@ -50,49 +50,47 @@ object Bagel extends Logging {
verts
}
- def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
- C : Manifest](
+ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
- numSplits: Int
+ numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
run[K, V, M, C, Nothing](
- sc, vertices, messages, combiner, None, partitioner, numSplits)(
+ sc, vertices, messages, combiner, None, partitioner, numPartitions)(
addAggregatorArg[K, V, M, C](compute))
}
- def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
- C : Manifest](
+ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
- numSplits: Int
+ numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
- val part = new HashPartitioner(numSplits)
+ val part = new HashPartitioner(numPartitions)
run[K, V, M, C, Nothing](
- sc, vertices, messages, combiner, None, part, numSplits)(
+ sc, vertices, messages, combiner, None, part, numPartitions)(
addAggregatorArg[K, V, M, C](compute))
}
- def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
+ def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
- numSplits: Int
+ numPartitions: Int
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
- val part = new HashPartitioner(numSplits)
+ val part = new HashPartitioner(numPartitions)
run[K, V, M, Array[M], Nothing](
- sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)(
+ sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions)(
addAggregatorArg[K, V, M, Array[M]](compute))
}
@@ -100,7 +98,7 @@ object Bagel extends Logging {
* Aggregates the given vertices using the given aggregator, if it
* is specified.
*/
- private def agg[K, V <: Vertex, A : Manifest](
+ private def agg[K, V <: Vertex, A: Manifest](
verts: RDD[(K, V)],
aggregator: Option[Aggregator[V, A]]
): Option[A] = aggregator match {
@@ -116,7 +114,7 @@ object Bagel extends Logging {
* function. Returns the processed RDD, the number of messages
* created, and the number of active vertices.
*/
- private def comp[K : Manifest, V <: Vertex, M <: Message[K], C](
+ private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
grouped: RDD[(K, (Seq[C], Seq[V]))],
compute: (V, Option[C]) => (V, Array[M])
@@ -149,9 +147,7 @@ object Bagel extends Logging {
* Converts a compute function that doesn't take an aggregator to
* one that does, so it can be passed to Bagel.run.
*/
- private def addAggregatorArg[
- K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C
- ](
+ private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C](
compute: (V, Option[C], Int) => (V, Array[M])
): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
(vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
@@ -170,7 +166,7 @@ trait Aggregator[V, A] {
def mergeAggregators(a: A, b: A): A
}
-class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable {
+class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
def createCombiner(msg: M): Array[M] =
Array(msg)
def mergeMsg(combiner: Array[M], msg: M): Array[M] =
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
index 03843019c0..bc32663e0f 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
@@ -16,7 +16,7 @@ import scala.xml.{XML,NodeSeq}
object WikipediaPageRank {
def main(args: Array[String]) {
if (args.length < 5) {
- System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> <usePartitioner>")
+ System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numPartitions> <host> <usePartitioner>")
System.exit(-1)
}
@@ -25,7 +25,7 @@ object WikipediaPageRank {
val inputFile = args(0)
val threshold = args(1).toDouble
- val numSplits = args(2).toInt
+ val numPartitions = args(2).toInt
val host = args(3)
val usePartitioner = args(4).toBoolean
val sc = new SparkContext(host, "WikipediaPageRank")
@@ -69,7 +69,7 @@ object WikipediaPageRank {
val result =
Bagel.run(
sc, vertices, messages, combiner = new PRCombiner(),
- numSplits = numSplits)(
+ numPartitions = numPartitions)(
utils.computeWithCombiner(numVertices, epsilon))
// Print the result
diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
index 06cc8c748b..9d9d80d809 100644
--- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
+++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala
@@ -88,7 +88,7 @@ object WikipediaPageRankStandalone {
n: Long,
partitioner: Partitioner,
usePartitioner: Boolean,
- numSplits: Int
+ numPartitions: Int
): RDD[(String, Double)] = {
var ranks = links.mapValues { edges => defaultRank }
for (i <- 1 to numIterations) {
diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 3c2f9c4616..47829a431e 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -1,10 +1,8 @@
package spark.bagel
import org.scalatest.{FunSuite, Assertions, BeforeAndAfter}
-import org.scalatest.prop.Checkers
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
import scala.collection.mutable.ArrayBuffer
@@ -13,7 +11,7 @@ import spark._
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
-class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
+class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
@@ -25,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
}
-
+
test("halting by voting") {
sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
@@ -36,8 +34,9 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
- for ((id, vert) <- result.collect)
+ for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
+ }
}
test("halting by message silence") {
@@ -57,7 +56,27 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
}
(new TestVertex(self.active, self.age + 1), msgsOut)
}
- for ((id, vert) <- result.collect)
+ for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
+ }
+ }
+
+ test("large number of iterations") {
+ // This tests whether jobs with a large number of iterations finish in a reasonable time,
+ // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
+ failAfter(10 seconds) {
+ sc = new SparkContext("local", "test")
+ val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
+ val msgs = sc.parallelize(Array[(String, TestMessage)]())
+ val numSupersteps = 50
+ val result =
+ Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
+ (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
+ (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
+ }
+ for ((id, vert) <- result.collect) {
+ assert(vert.age === numSupersteps)
+ }
+ }
}
}
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
index 711435c333..c7b379a3fb 100644
--- a/core/src/main/scala/spark/CacheManager.scala
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -11,13 +11,13 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
- // Split is in cache, so just return its values
+ // Partition is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala
index b2a0e2b631..178d31a73b 100644
--- a/core/src/main/scala/spark/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala
@@ -42,14 +42,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** (Experimental) Approximate operation to return the mean within a timeout. */
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new MeanEvaluator(self.splits.size, confidence)
+ val evaluator = new MeanEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
- val evaluator = new SumEvaluator(self.splits.size, confidence)
+ val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
}
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index cc3cca2571..e7408e4352 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -23,6 +23,7 @@ import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.rdd._
import spark.SparkContext._
+import spark.Partitioner._
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@@ -62,7 +63,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
- if (mapSideCombine) {
+ if (self.partitioner == Some(partitioner)) {
+ self.mapPartitions(aggregator.combineValuesByKey(_), true)
+ } else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
@@ -81,8 +84,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
- numSplits: Int): RDD[(K, C)] = {
- combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+ numPartitions: Int): RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
}
/**
@@ -143,10 +146,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
- def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
- reduceByKey(new HashPartitioner(numSplits), func)
+ def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = {
+ reduceByKey(new HashPartitioner(numPartitions), func)
}
/**
@@ -164,10 +167,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with into `numSplits` partitions.
+ * resulting RDD with into `numPartitions` partitions.
*/
- def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
- groupByKey(new HashPartitioner(numSplits))
+ def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = {
+ groupByKey(new HashPartitioner(numPartitions))
}
/**
@@ -246,8 +249,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
- * parallelism level.
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the
+ * existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
: RDD[(K, C)] = {
@@ -257,7 +260,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
+ * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
+ * parallelism level.
*/
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
reduceByKey(defaultPartitioner(self), func)
@@ -265,7 +269,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with the default parallelism level.
+ * resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): RDD[(K, Seq[V])] = {
groupByKey(defaultPartitioner(self))
@@ -285,15 +289,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
- def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
- join(other, new HashPartitioner(numSplits))
+ def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = {
+ join(other, new HashPartitioner(numPartitions))
}
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * using the default level of parallelism.
+ * using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, defaultPartitioner(self, other))
@@ -303,17 +307,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * into `numSplits` partitions.
+ * into `numPartitions` partitions.
*/
- def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
- leftOuterJoin(other, new HashPartitioner(numSplits))
+ def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = {
+ leftOuterJoin(other, new HashPartitioner(numPartitions))
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
- * RDD using the default parallelism level.
+ * RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, defaultPartitioner(self, other))
@@ -325,8 +329,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
- def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
- rightOuterJoin(other, new HashPartitioner(numSplits))
+ def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = {
+ rightOuterJoin(other, new HashPartitioner(numPartitions))
}
/**
@@ -361,7 +365,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](
- Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
+ Seq(self.asInstanceOf[RDD[(K, _)]], other.asInstanceOf[RDD[(K, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
prfs.mapValues {
@@ -380,9 +384,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](
- Seq(self.asInstanceOf[RDD[(_, _)]],
- other1.asInstanceOf[RDD[(_, _)]],
- other2.asInstanceOf[RDD[(_, _)]]),
+ Seq(self.asInstanceOf[RDD[(K, _)]],
+ other1.asInstanceOf[RDD[(K, _)]],
+ other2.asInstanceOf[RDD[(K, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
prfs.mapValues {
@@ -412,17 +416,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = {
- cogroup(other, new HashPartitioner(numSplits))
+ def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, new HashPartitioner(numPartitions))
}
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
- def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int)
+ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
- cogroup(other1, other2, new HashPartitioner(numSplits))
+ cogroup(other1, other2, new HashPartitioner(numPartitions))
}
/** Alias for cogroup. */
@@ -437,17 +441,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
- * the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
- */
- def defaultPartitioner(rdds: RDD[_]*): Partitioner = {
- for (r <- rdds if r.partitioner != None) {
- return r.partitioner.get
- }
- return new HashPartitioner(self.context.defaultParallelism)
- }
-
- /**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
@@ -634,9 +627,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
- def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[(K,V)] = {
val shuffled =
- new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending))
+ new ShuffledRDD[K, V](self, new RangePartitioner(numPartitions, self, ascending))
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
@@ -650,9 +643,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
- override def getSplits = firstParent[(K, V)].splits
+ override def getPartitions = firstParent[(K, V)].partitions
override val partitioner = firstParent[(K, V)].partitioner
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
@@ -660,9 +653,9 @@ private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev) {
- override def getSplits = firstParent[(K, V)].splits
+ override def getPartitions = firstParent[(K, V)].partitions
override val partitioner = firstParent[(K, V)].partitioner
- override def compute(split: Split, context: TaskContext) = {
+ override def compute(split: Partition, context: TaskContext) = {
firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
diff --git a/core/src/main/scala/spark/Split.scala b/core/src/main/scala/spark/Partition.scala
index 90d4b47c55..e384308ef6 100644
--- a/core/src/main/scala/spark/Split.scala
+++ b/core/src/main/scala/spark/Partition.scala
@@ -3,7 +3,7 @@ package spark
/**
* A partition of an RDD.
*/
-trait Split extends Serializable {
+trait Partition extends Serializable {
/**
* Get the split's index within its parent RDD
*/
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 9d5b966e1e..eec0e8dd79 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -9,6 +9,38 @@ abstract class Partitioner extends Serializable {
def getPartition(key: Any): Int
}
+object Partitioner {
+
+ private val useDefaultParallelism = System.getProperty("spark.default.parallelism") != null
+
+ /**
+ * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
+ *
+ * If any of the RDDs already has a partitioner, choose that one.
+ *
+ * Otherwise, we use a default HashPartitioner. For the number of partitions, if
+ * spark.default.parallelism is set, then we'll use the value from SparkContext
+ * defaultParallelism, otherwise we'll use the max number of upstream partitions.
+ *
+ * Unless spark.default.parallelism is set, He number of partitions will be the
+ * same as the number of partitions in the largest upstream RDD, as this should
+ * be least likely to cause out-of-memory errors.
+ *
+ * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
+ */
+ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
+ val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
+ for (r <- bySize if r.partitioner != None) {
+ return r.partitioner.get
+ }
+ if (useDefaultParallelism) {
+ return new HashPartitioner(rdd.context.defaultParallelism)
+ } else {
+ return new HashPartitioner(bySize.head.partitions.size)
+ }
+ }
+}
+
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
*
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 9d6ea782bd..584efa8adf 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -16,19 +16,22 @@ import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
+import spark.rdd.CoalescedRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
import spark.rdd.GlommedRDD
import spark.rdd.MappedRDD
import spark.rdd.MapPartitionsRDD
-import spark.rdd.MapPartitionsWithSplitRDD
+import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
+import spark.rdd.SubtractedRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
import spark.storage.StorageLevel
@@ -48,7 +51,7 @@ import SparkContext._
*
* Internally, each RDD is characterized by five main properties:
*
- * - A list of splits (partitions)
+ * - A list of partitions
* - A function for computing each split
* - A list of dependencies on other RDDs
* - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
@@ -75,13 +78,13 @@ abstract class RDD[T: ClassManifest](
// =======================================================================
/** Implemented by subclasses to compute a given partition. */
- def compute(split: Split, context: TaskContext): Iterator[T]
+ def compute(split: Partition, context: TaskContext): Iterator[T]
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
- protected def getSplits: Array[Split]
+ protected def getPartitions: Array[Partition]
/**
* Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
@@ -90,7 +93,7 @@ abstract class RDD[T: ClassManifest](
protected def getDependencies: Seq[Dependency[_]] = deps
/** Optionally overridden by subclasses to specify placement preferences. */
- protected def getPreferredLocations(split: Split): Seq[String] = Nil
+ protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
@@ -136,10 +139,10 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- // Our dependencies and splits will be gotten by calling subclass's methods below, and will
+ // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = null
- @transient private var splits_ : Array[Split] = null
+ @transient private var partitions_ : Array[Partition] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -158,15 +161,15 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Get the array of splits of this RDD, taking into account whether the
+ * Get the array of partitions of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
- final def splits: Array[Split] = {
- checkpointRDD.map(_.splits).getOrElse {
- if (splits_ == null) {
- splits_ = getSplits
+ final def partitions: Array[Partition] = {
+ checkpointRDD.map(_.partitions).getOrElse {
+ if (partitions_ == null) {
+ partitions_ = getPartitions
}
- splits_
+ partitions_
}
}
@@ -174,7 +177,7 @@ abstract class RDD[T: ClassManifest](
* Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
- final def preferredLocations(split: Split): Seq[String] = {
+ final def preferredLocations(split: Partition): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
getPreferredLocations(split)
}
@@ -185,7 +188,7 @@ abstract class RDD[T: ClassManifest](
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- final def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
@@ -196,7 +199,7 @@ abstract class RDD[T: ClassManifest](
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
- private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
+ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
firstParent[T].iterator(split, context)
} else {
@@ -226,10 +229,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): RDD[T] =
- map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(numPartitions: Int): RDD[T] =
+ map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)
- def distinct(): RDD[T] = distinct(splits.size)
+ def distinct(): RDD[T] = distinct(partitions.size)
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): RDD[T] = new CoalescedRDD(this, numPartitions)
/**
* Return a sampled subset of this RDD.
@@ -294,18 +302,25 @@ abstract class RDD[T: ClassManifest](
def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
+ * Return an RDD of grouped items.
+ */
+ def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+ groupBy[K](f, defaultPartitioner(this))
+
+ /**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
- val cleanF = sc.clean(f)
- this.map(t => (cleanF(t), t)).groupByKey(numSplits)
- }
+ def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ groupBy(f, new HashPartitioner(numPartitions))
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism)
+ def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ val cleanF = sc.clean(f)
+ this.map(t => (cleanF(t), t)).groupByKey(p)
+ }
/**
* Return an RDD created by piping elements to a forked external process.
@@ -330,14 +345,24 @@ abstract class RDD[T: ClassManifest](
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
- /**
+ /**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
+ def mapPartitionsWithIndex[U: ClassManifest](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ @deprecated("use mapPartitionsWithIndex")
def mapPartitionsWithSplit[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning)
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
@@ -378,7 +403,27 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Reduces the elements of this RDD using the specified associative binary operator.
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: RDD[T]): RDD[T] =
+ subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: RDD[T], numPartitions: Int): RDD[T] =
+ subtract(other, new HashPartitioner(numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+
+ /**
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
@@ -465,7 +510,7 @@ abstract class RDD[T: ClassManifest](
}
result
}
- val evaluator = new CountEvaluator(splits.size, confidence)
+ val evaluator = new CountEvaluator(partitions.size, confidence)
sc.runApproximateJob(this, countElements, evaluator, timeout)
}
@@ -516,7 +561,7 @@ abstract class RDD[T: ClassManifest](
}
map
}
- val evaluator = new GroupedCountEvaluator[T](splits.size, confidence)
+ val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence)
sc.runApproximateJob(this, countPartition, evaluator, timeout)
}
@@ -531,7 +576,7 @@ abstract class RDD[T: ClassManifest](
}
val buf = new ArrayBuffer[T]
var p = 0
- while (buf.size < num && p < splits.size) {
+ while (buf.size < num && p < partitions.size) {
val left = num - buf.size
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
buf ++= res(0)
@@ -630,27 +675,32 @@ abstract class RDD[T: ClassManifest](
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
+ // Avoid handling doCheckpoint multiple times to prevent excessive recursion
+ private var doCheckpointCalled = false
+
/**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
private[spark] def doCheckpoint() {
- if (checkpointData.isDefined) {
- checkpointData.get.doCheckpoint()
- } else {
- dependencies.foreach(_.rdd.doCheckpoint())
+ if (!doCheckpointCalled) {
+ doCheckpointCalled = true
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
}
}
/**
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
- * created from the checkpoint file, and forget its old dependencies and splits.
+ * created from the checkpoint file, and forget its old dependencies and partitions.
*/
private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
- dependencies_ = null
- splits_ = null
+ partitions_ = null
deps = null // Forget the constructor argument for dependencies too
}
@@ -665,15 +715,15 @@ abstract class RDD[T: ClassManifest](
}
/** A description of this RDD and its recursive dependencies for debugging. */
- def toDebugString(): String = {
+ def toDebugString: String = {
def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
- Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++
+ Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++
rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
}
debugString(this).mkString("\n")
}
- override def toString(): String = "%s%s[%d] at %s".format(
+ override def toString: String = "%s%s[%d] at %s".format(
Option(name).map(_ + " ").getOrElse(""),
getClass.getSimpleName,
id,
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index a4a4ebaf53..d00092e984 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -16,7 +16,7 @@ private[spark] object CheckpointState extends Enumeration {
/**
* This class contains all the information related to RDD checkpointing. Each instance of this class
* is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as,
- * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
@@ -67,11 +67,11 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
val newRDD = new CheckpointRDD[T](rdd.context, path)
- // Change the dependencies and splits of the RDD
+ // Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
- rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@@ -79,15 +79,15 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
}
// Get preferred location of a split after checkpointing
- def getPreferredLocations(split: Split): Seq[String] = {
+ def getPreferredLocations(split: Partition): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
}
- def getSplits: Array[Split] = {
+ def getPartitions: Array[Partition] = {
RDDCheckpointData.synchronized {
- cpRDD.get.splits
+ cpRDD.get.partitions
}
}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0efc00d5dd..df23710d46 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -39,7 +39,7 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
@@ -53,7 +53,7 @@ import storage.{StorageStatus, StorageUtils, RDDInfo}
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI.
+ * @param appName A name for your application, to display on the cluster web UI.
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
@@ -61,7 +61,7 @@ import storage.{StorageStatus, StorageUtils, RDDInfo}
*/
class SparkContext(
val master: String,
- val jobName: String,
+ val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
environment: Map[String, String] = Map())
@@ -143,7 +143,7 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
scheduler
@@ -162,7 +162,7 @@ class SparkContext(
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -178,9 +178,9 @@ class SparkContext(
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
+ new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
+ new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
}
scheduler.initialize(backend)
scheduler
@@ -216,7 +216,7 @@ class SparkContext(
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
- new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
+ new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
@@ -229,7 +229,7 @@ class SparkContext(
* Create a new partition for each collection item. */
def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
- new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
/**
@@ -439,7 +439,7 @@ class SparkContext(
}
/**
- * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
+ * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
@@ -614,14 +614,14 @@ class SparkContext(
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.splits.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
- runJob(rdd, func, 0 until rdd.splits.size, false)
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
@@ -632,7 +632,7 @@ class SparkContext(
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
- runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler)
+ runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
}
/**
@@ -644,7 +644,7 @@ class SparkContext(
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
- runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler)
+ runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
}
/**
@@ -693,10 +693,10 @@ class SparkContext(
checkpointDir = Some(dir)
}
- /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
+ /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
def defaultParallelism: Int = taskScheduler.defaultParallelism
- /** Default min number of splits for Hadoop RDDs when not given by user */
+ /** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private var nextShuffleId = new AtomicInteger(0)
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 28d643abca..81daacf958 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -454,4 +454,25 @@ private object Utils extends Logging {
def clone[T](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
+
+ /**
+ * Detect whether this thread might be executing a shutdown hook. Will always return true if
+ * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g.
+ * if System.exit was just called by a concurrent thread).
+ *
+ * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing
+ * an IllegalStateException.
+ */
+ def inShutdown(): Boolean = {
+ try {
+ val hook = new Thread {
+ override def run() {}
+ }
+ Runtime.getRuntime.addShutdownHook(hook)
+ Runtime.getRuntime.removeShutdownHook(hook)
+ } catch {
+ case ise: IllegalStateException => return true
+ }
+ return false
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
index 843e1bd18b..ba00b6a844 100644
--- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -6,8 +6,8 @@ import spark.api.java.function.{Function => JFunction}
import spark.util.StatCounter
import spark.partial.{BoundedDouble, PartialResult}
import spark.storage.StorageLevel
-
import java.lang.Double
+import spark.Partitioner
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
@@ -44,7 +44,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
@@ -53,6 +53,32 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.filter(x => f(x).booleanValue()))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other, p))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 8ce32e0e2f..c1bd13c49a 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -19,6 +19,7 @@ import spark.OrderedRDDFunctions
import spark.storage.StorageLevel
import spark.HashPartitioner
import spark.Partitioner
+import spark.Partitioner._
import spark.RDD
import spark.SparkContext.rddToPairRDDFunctions
@@ -54,15 +55,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
- def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
+ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.coalesce(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -97,7 +103,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
*/
- def combineByKey[C](createCombiner: Function[V, C],
+ def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
@@ -117,8 +123,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
- numSplits: Int): JavaPairRDD[K, C] =
- combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
+ numPartitions: Int): JavaPairRDD[K, C] =
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
/**
* Merge the values for each key using an associative reduce function. This will also perform
@@ -157,10 +163,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
- def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] =
- fromRDD(rdd.reduceByKey(func, numSplits))
+ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(func, numPartitions))
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
@@ -171,10 +177,31 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with into `numSplits` partitions.
+ * resulting RDD with into `numPartitions` partitions.
+ */
+ def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions)))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
*/
- def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] =
- fromRDD(groupByResultToJava(rdd.groupByKey(numSplits)))
+ def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other, p))
/**
* Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
@@ -215,30 +242,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.rightOuterJoin(other, partitioner))
/**
- * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
- * parallelism level.
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
+ * partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
implicit val cm: ClassManifest[C] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
- fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
+ fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
+ * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
+ * parallelism level.
*/
def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
- val partitioner = rdd.defaultPartitioner(rdd)
- fromRDD(reduceByKey(partitioner, func))
+ fromRDD(reduceByKey(defaultPartitioner(rdd), func))
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
- * resulting RDD with the default parallelism level.
+ * resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
@@ -256,14 +283,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
- def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] =
- fromRDD(rdd.join(other, numSplits))
+ def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, numPartitions))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * using the default level of parallelism.
+ * using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other))
@@ -272,16 +299,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
- * into `numSplits` partitions.
+ * into `numPartitions` partitions.
*/
- def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] =
- fromRDD(rdd.leftOuterJoin(other, numSplits))
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Option[W])] =
+ fromRDD(rdd.leftOuterJoin(other, numPartitions))
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
- * RDD using the default parallelism level.
+ * RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other))
@@ -292,8 +319,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
- def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] =
- fromRDD(rdd.rightOuterJoin(other, numSplits))
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Option[V], W)] =
+ fromRDD(rdd.rightOuterJoin(other, numPartitions))
/**
* Return the key-value pairs in this RDD to the master as a Map.
@@ -304,7 +331,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Pass each value in the key-value pair RDD through a map function without changing the keys;
* this also retains the original RDD's partitioning.
*/
- def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = {
+ def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
fromRDD(rdd.mapValues(f))
@@ -357,16 +384,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
- def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])]
- = fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits)))
+ def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])]
+ = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
- def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int)
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
- fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits)))
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
@@ -447,7 +474,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
*/
def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
- sortByKey(comp, true)
+ sortByKey(comp, ascending)
}
/**
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index ac31350ec3..3016888898 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -30,7 +30,7 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
+ def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
@@ -39,6 +39,11 @@ JavaRDDLike[T, JavaRDD[T]] {
wrapRDD(rdd.filter((x => f(x).booleanValue())))
/**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
@@ -50,6 +55,26 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] =
+ wrapRDD(rdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
+ wrapRDD(rdd.subtract(other, p))
+
}
object JavaRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 60025b459c..d884529d7a 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -4,7 +4,7 @@ import java.util.{List => JList}
import scala.Tuple2
import scala.collection.JavaConversions._
-import spark.{SparkContext, Split, RDD, TaskContext}
+import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
@@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@@ -20,7 +20,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
def rdd: RDD[T]
/** Set of partitions in this RDD. */
- def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+ def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
/** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
@@ -36,7 +36,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
- def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] =
+ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] =
asJavaIterator(rdd.iterator(split, taskContext))
// Transformations (return a new RDD)
@@ -82,12 +82,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
}
/**
- * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
*/
- private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -110,8 +111,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
- JavaPairRDD[K, V] = {
+ def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
+ JavaPairRDD[K2, V2] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
@@ -146,12 +147,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = {
+ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val vcm: ClassManifest[JList[T]] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm)
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
/**
@@ -201,7 +202,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
}
/**
- * Reduces the elements of this RDD using the specified associative binary operator.
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
@@ -333,6 +334,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
- rdd.toDebugString()
+ rdd.toDebugString
}
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index 50b8970cd8..f75fc27c7b 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -23,41 +23,41 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
*/
- def this(master: String, jobName: String) = this(new SparkContext(master, jobName))
+ def this(master: String, appName: String) = this(new SparkContext(master, appName))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
- def this(master: String, jobName: String, sparkHome: String, jarFile: String) =
- this(new SparkContext(master, jobName, sparkHome, Seq(jarFile)))
+ def this(master: String, appName: String, sparkHome: String, jarFile: String) =
+ this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
- def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) =
- this(new SparkContext(master, jobName, sparkHome, jars.toSeq))
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String]) =
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
+ * @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes
*/
- def this(master: String, jobName: String, sparkHome: String, jars: Array[String],
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String],
environment: JMap[String, String]) =
- this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment))
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
private[spark] val env = sc.env
diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
deleted file mode 100644
index 68b6fd6622..0000000000
--- a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
+++ /dev/null
@@ -1,20 +0,0 @@
-package spark.api.java;
-
-import spark.api.java.JavaPairRDD;
-import spark.api.java.JavaRDDLike;
-import spark.api.java.function.PairFlatMapFunction;
-
-import java.io.Serializable;
-
-/**
- * Workaround for SPARK-668.
- */
-class PairFlatMapWorkaround<T> implements Serializable {
- /**
- * Return a new RDD by first applying a function to all elements of this
- * RDD, and then flattening the results.
- */
- public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
- return ((JavaRDDLike <T, ?>) this).doFlatMap(f);
- }
-}
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
index 519e310323..d618c098c2 100644
--- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -9,7 +9,7 @@ import java.util.Arrays
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
- * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index ab8351e55e..8c73477384 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -32,11 +32,11 @@ private[spark] class PythonRDD[T: ClassManifest](
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
broadcastVars, accumulator)
- override def getSplits = parent.splits
+ override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
- override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
@@ -65,7 +65,7 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
- // Split index
+ // Partition index
dOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
@@ -155,8 +155,8 @@ private class PythonException(msg: String) extends Exception(msg)
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
- override def getSplits = prev.splits
- override def compute(split: Split, context: TaskContext) =
+ override def getPartitions = prev.partitions
+ override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
index 7160fc05fc..6659e53b25 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -1,6 +1,6 @@
package spark.deploy
-private[spark] class JobDescription(
+private[spark] class ApplicationDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
@@ -10,5 +10,5 @@ private[spark] class JobDescription(
val user = System.getProperty("user.name", "<unknown>")
- override def toString: String = "JobDescription(" + name + ")"
+ override def toString: String = "ApplicationDescription(" + name + ")"
}
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 35f40c6e91..3cbf4fdd98 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -1,7 +1,7 @@
package spark.deploy
import spark.deploy.ExecutorState.ExecutorState
-import spark.deploy.master.{WorkerInfo, JobInfo}
+import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
@@ -23,37 +23,39 @@ case class RegisterWorker(
private[spark]
case class ExecutorStateChanged(
- jobId: String,
+ appId: String,
execId: Int,
state: ExecutorState,
message: Option[String],
exitStatus: Option[Int])
extends DeployMessage
+private[spark] case class Heartbeat(workerId: String) extends DeployMessage
+
// Master to Worker
private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
-private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
+private[spark] case class KillExecutor(appId: String, execId: Int) extends DeployMessage
private[spark] case class LaunchExecutor(
- jobId: String,
+ appId: String,
execId: Int,
- jobDesc: JobDescription,
+ appDesc: ApplicationDescription,
cores: Int,
memory: Int,
sparkHome: String)
extends DeployMessage
-
// Client to Master
-private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
+private[spark] case class RegisterApplication(appDescription: ApplicationDescription)
+ extends DeployMessage
// Master to Client
private[spark]
-case class RegisteredJob(jobId: String) extends DeployMessage
+case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
@@ -63,7 +65,7 @@ case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String
exitStatus: Option[Int])
private[spark]
-case class JobKilled(message: String)
+case class appKilled(message: String)
// Internal message in Client
@@ -76,8 +78,11 @@ private[spark] case object RequestMasterState
// Master to MasterWebUI
private[spark]
-case class MasterState(uri: String, workers: Array[WorkerInfo], activeJobs: Array[JobInfo],
- completedJobs: Array[JobInfo])
+case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+
+ def uri = "spark://" + host + ":" + port
+}
// WorkerWebUI to Worker
private[spark] case object RequestWorkerState
@@ -85,6 +90,6 @@ private[spark] case object RequestWorkerState
// Worker to WorkerWebUI
private[spark]
-case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
+case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
index 732fa08064..38a6ebfc24 100644
--- a/core/src/main/scala/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -1,6 +1,6 @@
package spark.deploy
-import master.{JobInfo, WorkerInfo}
+import master.{ApplicationInfo, WorkerInfo}
import worker.ExecutorRunner
import cc.spray.json._
@@ -20,8 +20,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
)
}
- implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
- def write(obj: JobInfo) = JsObject(
+ implicit object AppInfoJsonFormat extends RootJsonWriter[ApplicationInfo] {
+ def write(obj: ApplicationInfo) = JsObject(
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
@@ -31,8 +31,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"submitdate" -> JsString(obj.submitDate.toString))
}
- implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
- def write(obj: JobDescription) = JsObject(
+ implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
+ def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
"cores" -> JsNumber(obj.cores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
@@ -44,8 +44,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: ExecutorRunner) = JsObject(
"id" -> JsNumber(obj.execId),
"memory" -> JsNumber(obj.memory),
- "jobid" -> JsString(obj.jobId),
- "jobdesc" -> obj.jobDesc.toJson.asJsObject
+ "appid" -> JsString(obj.appId),
+ "appdesc" -> obj.appDesc.toJson.asJsObject
)
}
@@ -57,8 +57,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
"memory" -> JsNumber(obj.workers.map(_.memory).sum),
"memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
- "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
- "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
+ "activeapps" -> JsArray(obj.activeApps.toList.map(_.toJson)),
+ "completedapps" -> JsArray(obj.completedApps.toList.map(_.toJson))
)
}
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index a63eee1233..1a95524cf9 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -8,25 +8,25 @@ import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
-import spark.deploy.RegisterJob
+import spark.deploy.RegisterApplication
import spark.deploy.master.Master
import akka.remote.RemoteClientDisconnected
import akka.actor.Terminated
import akka.dispatch.Await
/**
- * The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
- * and a listener for job events, and calls back the listener when various events occur.
+ * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
+ * and a listener for cluster events, and calls back the listener when various events occur.
*/
private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
- jobDescription: JobDescription,
+ appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
var actor: ActorRef = null
- var jobId: String = null
+ var appId: String = null
class ClientActor extends Actor with Logging {
var master: ActorRef = null
@@ -38,7 +38,7 @@ private[spark] class Client(
try {
master = context.actorFor(Master.toAkkaUrl(masterUrl))
masterAddress = master.path.address
- master ! RegisterJob(jobDescription)
+ master ! RegisterApplication(appDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
@@ -50,17 +50,17 @@ private[spark] class Client(
}
override def receive = {
- case RegisteredJob(jobId_) =>
- jobId = jobId_
- listener.connected(jobId)
+ case RegisteredApplication(appId_) =>
+ appId = appId_
+ listener.connected(appId)
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
- val fullId = jobId + "/" + id
+ val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
listener.executorAdded(fullId, workerId, host, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
- val fullId = jobId + "/" + id
+ val fullId = appId + "/" + id
val messageText = message.map(s => " (" + s + ")").getOrElse("")
logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText))
if (ExecutorState.isFinished(state)) {
@@ -107,7 +107,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = 1.seconds
+ val timeout = 5.seconds
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala
index 7035f4b394..b7008321df 100644
--- a/core/src/main/scala/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -8,7 +8,7 @@ package spark.deploy.client
* Users of this API should *not* block inside the callback methods.
*/
private[spark] trait ClientListener {
- def connected(jobId: String): Unit
+ def connected(appId: String): Unit
def disconnected(): Unit
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 8764c400e2..dc004b59ca 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -2,13 +2,13 @@ package spark.deploy.client
import spark.util.AkkaUtils
import spark.{Logging, Utils}
-import spark.deploy.{Command, JobDescription}
+import spark.deploy.{Command, ApplicationDescription}
private[spark] object TestClient {
class TestListener extends ClientListener with Logging {
def connected(id: String) {
- logInfo("Connected to master, got job ID " + id)
+ logInfo("Connected to master, got app ID " + id)
}
def disconnected() {
@@ -24,7 +24,7 @@ private[spark] object TestClient {
def main(args: Array[String]) {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
- val desc = new JobDescription(
+ val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
index a274b21c34..3591a94072 100644
--- a/core/src/main/scala/spark/deploy/master/JobInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -1,18 +1,18 @@
package spark.deploy.master
-import spark.deploy.JobDescription
+import spark.deploy.ApplicationDescription
import java.util.Date
import akka.actor.ActorRef
import scala.collection.mutable
-private[spark] class JobInfo(
+private[spark] class ApplicationInfo(
val startTime: Long,
val id: String,
- val desc: JobDescription,
+ val desc: ApplicationDescription,
val submitDate: Date,
val driver: ActorRef)
{
- var state = JobState.WAITING
+ var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
var coresGranted = 0
var endTime = -1L
@@ -48,7 +48,7 @@ private[spark] class JobInfo(
_retryCount
}
- def markFinished(endState: JobState.Value) {
+ def markFinished(endState: ApplicationState.Value) {
state = endState
endTime = System.currentTimeMillis()
}
diff --git a/core/src/main/scala/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/spark/deploy/master/ApplicationState.scala
new file mode 100644
index 0000000000..15016b388d
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/master/ApplicationState.scala
@@ -0,0 +1,11 @@
+package spark.deploy.master
+
+private[spark] object ApplicationState
+ extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+
+ type ApplicationState = Value
+
+ val WAITING, RUNNING, FINISHED, FAILED = Value
+
+ val MAX_NUM_RETRY = 10
+}
diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
index 1db2c32633..48e6055fb5 100644
--- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala
@@ -4,12 +4,12 @@ import spark.deploy.ExecutorState
private[spark] class ExecutorInfo(
val id: Int,
- val job: JobInfo,
+ val application: ApplicationInfo,
val worker: WorkerInfo,
val cores: Int,
val memory: Int) {
var state = ExecutorState.LAUNCHING
- def fullId: String = job.id + "/" + id
+ def fullId: String = application.id + "/" + id
}
diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala
deleted file mode 100644
index 2b70cf0191..0000000000
--- a/core/src/main/scala/spark/deploy/master/JobState.scala
+++ /dev/null
@@ -1,9 +0,0 @@
-package spark.deploy.master
-
-private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
- type JobState = Value
-
- val WAITING, RUNNING, FINISHED, FAILED = Value
-
- val MAX_NUM_RETRY = 10
-}
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index 92e7914b1b..b7f167425f 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -3,6 +3,7 @@ package spark.deploy.master
import akka.actor._
import akka.actor.Terminated
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.util.duration._
import java.text.SimpleDateFormat
import java.util.Date
@@ -15,21 +16,24 @@ import spark.util.AkkaUtils
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
- val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
+ val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
- var nextJobNumber = 0
+ var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
val actorToWorker = new HashMap[ActorRef, WorkerInfo]
val addressToWorker = new HashMap[Address, WorkerInfo]
- val jobs = new HashSet[JobInfo]
- val idToJob = new HashMap[String, JobInfo]
- val actorToJob = new HashMap[ActorRef, JobInfo]
- val addressToJob = new HashMap[Address, JobInfo]
+ val apps = new HashSet[ApplicationInfo]
+ val idToApp = new HashMap[String, ApplicationInfo]
+ val actorToApp = new HashMap[ActorRef, ApplicationInfo]
+ val addressToApp = new HashMap[Address, ApplicationInfo]
- val waitingJobs = new ArrayBuffer[JobInfo]
- val completedJobs = new ArrayBuffer[JobInfo]
+ val waitingApps = new ArrayBuffer[ApplicationInfo]
+ val completedApps = new ArrayBuffer[ApplicationInfo]
+
+ var firstApp: Option[ApplicationInfo] = None
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
@@ -37,15 +41,16 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
- // a flag that will perform round-robin scheduling across the nodes (spreading out each job
- // among all the nodes) instead of trying to consolidate each job onto a small # of nodes.
- val spreadOutJobs = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
+ // a flag that will perform round-robin scheduling across the nodes (spreading out each app
+ // among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
+ val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + ip + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
+ context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis)(timeOutDeadWorkers())
}
def startWebUi() {
@@ -73,92 +78,101 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- case RegisterJob(description) => {
- logInfo("Registering job " + description.name)
- val job = addJob(description, sender)
- logInfo("Registered job " + description.name + " with ID " + job.id)
- waitingJobs += job
+ case RegisterApplication(description) => {
+ logInfo("Registering app " + description.name)
+ val app = addApplication(description, sender)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ waitingApps += app
context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredJob(job.id)
+ sender ! RegisteredApplication(app.id)
schedule()
}
- case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => {
- val execOption = idToJob.get(jobId).flatMap(job => job.executors.get(execId))
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
+ val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
execOption match {
case Some(exec) => {
exec.state = state
- exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus)
+ exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
if (ExecutorState.isFinished(state)) {
- val jobInfo = idToJob(jobId)
- // Remove this executor from the worker and job
+ val appInfo = idToApp(appId)
+ // Remove this executor from the worker and app
logInfo("Removing executor " + exec.fullId + " because it is " + state)
- jobInfo.removeExecutor(exec)
+ appInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec)
// Only retry certain number of times so we don't go into an infinite loop.
- if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) {
+ if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) {
schedule()
} else {
- logError("Job %s with ID %s failed %d times, removing it".format(
- jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
- removeJob(jobInfo)
+ logError("Application %s with ID %s failed %d times, removing it".format(
+ appInfo.desc.name, appInfo.id, appInfo.retryCount))
+ removeApplication(appInfo)
}
}
}
case None =>
- logWarning("Got status update for unknown executor " + jobId + "/" + execId)
+ logWarning("Got status update for unknown executor " + appId + "/" + execId)
+ }
+ }
+
+ case Heartbeat(workerId) => {
+ idToWorker.get(workerId) match {
+ case Some(workerInfo) =>
+ workerInfo.lastHeartbeat = System.currentTimeMillis()
+ case None =>
+ logWarning("Got heartbeat from unregistered worker " + workerId)
}
}
case Terminated(actor) => {
- // The disconnected actor could've been either a worker or a job; remove whichever of
+ // The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
- actorToJob.get(actor).foreach(removeJob)
+ actorToApp.get(actor).foreach(removeApplication)
}
case RemoteClientDisconnected(transport, address) => {
- // The disconnected client could've been either a worker or a job; remove whichever it was
+ // The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToJob.get(address).foreach(removeJob)
+ addressToApp.get(address).foreach(removeApplication)
}
case RemoteClientShutdown(transport, address) => {
- // The disconnected client could've been either a worker or a job; remove whichever it was
+ // The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
- addressToJob.get(address).foreach(removeJob)
+ addressToApp.get(address).foreach(removeApplication)
}
case RequestMasterState => {
- sender ! MasterState(ip + ":" + port, workers.toArray, jobs.toArray, completedJobs.toArray)
+ sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
/**
- * Can a job use the given worker? True if the worker has enough memory and we haven't already
- * launched an executor for the job on it (right now the standalone backend doesn't like having
+ * Can an app use the given worker? True if the worker has enough memory and we haven't already
+ * launched an executor for the app on it (right now the standalone backend doesn't like having
* two executors on the same worker).
*/
- def canUse(job: JobInfo, worker: WorkerInfo): Boolean = {
- worker.memoryFree >= job.desc.memoryPerSlave && !worker.hasExecutor(job)
+ def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = {
+ worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app)
}
/**
- * Schedule the currently available resources among waiting jobs. This method will be called
- * every time a new job joins or resource availability changes.
+ * Schedule the currently available resources among waiting apps. This method will be called
+ * every time a new app joins or resource availability changes.
*/
def schedule() {
- // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first job
- // in the queue, then the second job, etc.
- if (spreadOutJobs) {
- // Try to spread out each job among all the nodes, until it has all its cores
- for (job <- waitingJobs if job.coresLeft > 0) {
+ // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
+ // in the queue, then the second app, etc.
+ if (spreadOutApps) {
+ // Try to spread out each app among all the nodes, until it has all its cores
+ for (app <- waitingApps if app.coresLeft > 0) {
val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
- .filter(canUse(job, _)).sortBy(_.coresFree).reverse
+ .filter(canUse(app, _)).sortBy(_.coresFree).reverse
val numUsable = usableWorkers.length
val assigned = new Array[Int](numUsable) // Number of cores to give on each node
- var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum)
+ var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
var pos = 0
while (toAssign > 0) {
if (usableWorkers(pos).coresFree - assigned(pos) > 0) {
@@ -170,22 +184,22 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
// Now that we've decided how many cores to give on each node, let's actually give them
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
- val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
- job.state = JobState.RUNNING
+ val exec = app.addExecutor(usableWorkers(pos), assigned(pos))
+ launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
}
}
}
} else {
- // Pack each job into as few nodes as possible until we've assigned all its cores
- for (worker <- workers if worker.coresFree > 0) {
- for (job <- waitingJobs if job.coresLeft > 0) {
- if (canUse(job, worker)) {
- val coresToUse = math.min(worker.coresFree, job.coresLeft)
+ // Pack each app into as few nodes as possible until we've assigned all its cores
+ for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) {
+ for (app <- waitingApps if app.coresLeft > 0) {
+ if (canUse(app, worker)) {
+ val coresToUse = math.min(worker.coresFree, app.coresLeft)
if (coresToUse > 0) {
- val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec, job.desc.sparkHome)
- job.state = JobState.RUNNING
+ val exec = app.addExecutor(worker, coresToUse)
+ launchExecutor(worker, exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
}
}
}
@@ -196,8 +210,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
- exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
+ exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
@@ -219,45 +233,65 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
for (exec <- worker.executors.values) {
- exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None)
- exec.job.executors -= exec.id
+ logInfo("Telling app of lost executor: " + exec.id)
+ exec.application.driver ! ExecutorUpdated(exec.id, ExecutorState.LOST, Some("worker lost"), None)
+ exec.application.removeExecutor(exec)
}
}
- def addJob(desc: JobDescription, driver: ActorRef): JobInfo = {
+ def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val job = new JobInfo(now, newJobId(date), desc, date, driver)
- jobs += job
- idToJob(job.id) = job
- actorToJob(driver) = job
- addressToJob(driver.path.address) = job
- return job
+ val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
+ apps += app
+ idToApp(app.id) = app
+ actorToApp(driver) = app
+ addressToApp(driver.path.address) = app
+ if (firstApp == None) {
+ firstApp = Some(app)
+ }
+ val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
+ if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
+ logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
+ }
+ return app
}
- def removeJob(job: JobInfo) {
- if (jobs.contains(job)) {
- logInfo("Removing job " + job.id)
- jobs -= job
- idToJob -= job.id
- actorToJob -= job.driver
- addressToWorker -= job.driver.path.address
- completedJobs += job // Remember it in our history
- waitingJobs -= job
- for (exec <- job.executors.values) {
+ def removeApplication(app: ApplicationInfo) {
+ if (apps.contains(app)) {
+ logInfo("Removing app " + app.id)
+ apps -= app
+ idToApp -= app.id
+ actorToApp -= app.driver
+ addressToWorker -= app.driver.path.address
+ completedApps += app // Remember it in our history
+ waitingApps -= app
+ for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(exec.job.id, exec.id)
+ exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
}
- job.markFinished(JobState.FINISHED) // TODO: Mark it as FAILED if it failed
+ app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed
schedule()
}
}
- /** Generate a new job ID given a job's submission date */
- def newJobId(submitDate: Date): String = {
- val jobId = "job-%s-%04d".format(DATE_FORMAT.format(submitDate), nextJobNumber)
- nextJobNumber += 1
- jobId
+ /** Generate a new app ID given a app's submission date */
+ def newApplicationId(submitDate: Date): String = {
+ val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber)
+ nextAppNumber += 1
+ appId
+ }
+
+ /** Check for, and remove, any timed-out workers */
+ def timeOutDeadWorkers() {
+ // Copy the workers into an array so we don't modify the hashset while iterating through it
+ val expirationTime = System.currentTimeMillis() - WORKER_TIMEOUT
+ val toRemove = workers.filter(_.lastHeartbeat < expirationTime).toArray
+ for (worker <- toRemove) {
+ logWarning("Removing %s because we got no heartbeat in %d seconds".format(
+ worker.id, WORKER_TIMEOUT))
+ removeWorker(worker)
+ }
}
}
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 529f72e9da..54faa375fb 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -40,27 +40,27 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
}
}
} ~
- path("job") {
- parameters("jobId", 'format ?) {
- case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
+ path("app") {
+ parameters("appId", 'format ?) {
+ case (appId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
- val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
- masterState.activeJobs.find(_.id == jobId).getOrElse({
- masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ val appInfo = for (masterState <- future.mapTo[MasterState]) yield {
+ masterState.activeApps.find(_.id == appId).getOrElse({
+ masterState.completedApps.find(_.id == appId).getOrElse(null)
})
}
respondWithMediaType(MediaTypes.`application/json`) { ctx =>
- ctx.complete(jobInfo.mapTo[JobInfo])
+ ctx.complete(appInfo.mapTo[ApplicationInfo])
}
- case (jobId, _) =>
+ case (appId, _) =>
completeWith {
val future = master ? RequestMasterState
future.map { state =>
val masterState = state.asInstanceOf[MasterState]
- val job = masterState.activeJobs.find(_.id == jobId).getOrElse({
- masterState.completedJobs.find(_.id == jobId).getOrElse(null)
+ val app = masterState.activeApps.find(_.id == appId).getOrElse({
+ masterState.completedApps.find(_.id == appId).getOrElse(null)
})
- spark.deploy.master.html.job_details.render(job)
+ spark.deploy.master.html.app_details.render(app)
}
}
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index 5a7f5fef8a..23df1bb463 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -18,6 +18,8 @@ private[spark] class WorkerInfo(
var coresUsed = 0
var memoryUsed = 0
+ var lastHeartbeat = System.currentTimeMillis()
+
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
@@ -35,8 +37,8 @@ private[spark] class WorkerInfo(
}
}
- def hasExecutor(job: JobInfo): Boolean = {
- executors.values.exists(_.job == job)
+ def hasExecutor(app: ApplicationInfo): Boolean = {
+ executors.values.exists(_.application == app)
}
def webUiAddress : String = {
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 4ef637090c..de11771c8e 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -1,7 +1,7 @@
package spark.deploy.worker
import java.io._
-import spark.deploy.{ExecutorState, ExecutorStateChanged, JobDescription}
+import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
import java.net.{URI, URL}
@@ -14,9 +14,9 @@ import spark.deploy.ExecutorStateChanged
* Manages the execution of one executor process.
*/
private[spark] class ExecutorRunner(
- val jobId: String,
+ val appId: String,
val execId: Int,
- val jobDesc: JobDescription,
+ val appDesc: ApplicationDescription,
val cores: Int,
val memory: Int,
val worker: ActorRef,
@@ -26,7 +26,7 @@ private[spark] class ExecutorRunner(
val workDir: File)
extends Logging {
- val fullId = jobId + "/" + execId
+ val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
var shutdownHook: Thread = null
@@ -60,7 +60,7 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None, None)
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -74,10 +74,10 @@ private[spark] class ExecutorRunner(
}
def buildCommandSeq(): Seq[String] = {
- val command = jobDesc.command
- val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run";
+ val command = appDesc.command
+ val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
val runScript = new File(sparkHome, script).getCanonicalPath
- Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
+ Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/** Spawn a thread that will redirect a given stream to a file */
@@ -96,12 +96,12 @@ private[spark] class ExecutorRunner(
}
/**
- * Download and run the executor described in our JobDescription
+ * Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
// Create the executor's working directory
- val executorDir = new File(workDir, jobId + "/" + execId)
+ val executorDir = new File(workDir, appId + "/" + execId)
if (!executorDir.mkdirs()) {
throw new IOException("Failed to create directory " + executorDir)
}
@@ -110,7 +110,7 @@ private[spark] class ExecutorRunner(
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
val env = builder.environment()
- for ((key, value) <- jobDesc.command.environment) {
+ for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
env.put("SPARK_MEM", memory.toString + "m")
@@ -128,7 +128,7 @@ private[spark] class ExecutorRunner(
// times on the same machine.
val exitCode = process.waitFor()
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message),
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
Some(exitCode))
} catch {
case interrupted: InterruptedException =>
@@ -140,7 +140,7 @@ private[spark] class ExecutorRunner(
process.destroy()
}
val message = e.getClass + ": " + e.getMessage
- worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message), None)
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
}
}
}
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 38547ec4f1..2bbc931316 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -2,6 +2,7 @@ package spark.deploy.worker
import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.util.duration._
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import spark.deploy._
@@ -26,6 +27,9 @@ private[spark] class Worker(
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
+ // Send a heartbeat every (heartbeat timeout) / 4 milliseconds
+ val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+
var master: ActorRef = null
var masterWebUiUrl : String = ""
val workerId = generateWorkerId()
@@ -97,24 +101,27 @@ private[spark] class Worker(
case RegisteredWorker(url) =>
masterWebUiUrl = url
logInfo("Successfully registered with master")
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
+ master ! Heartbeat(workerId)
+ }
case RegisterWorkerFailed(message) =>
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
- logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
+ case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
- executors(jobId + "/" + execId) = manager
+ appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+ executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- master ! ExecutorStateChanged(jobId, execId, ExecutorState.RUNNING, None, None)
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
- case ExecutorStateChanged(jobId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(jobId, execId, state, message, exitStatus)
- val fullId = jobId + "/" + execId
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
logInfo("Executor " + fullId + " finished with state " + state +
@@ -126,8 +133,8 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
- case KillExecutor(jobId, execId) =>
- val fullId = jobId + "/" + execId
+ case KillExecutor(appId, execId) =>
+ val fullId = appId + "/" + execId
executors.get(fullId) match {
case Some(executor) =>
logInfo("Asked to kill executor " + fullId)
@@ -140,7 +147,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
- sender ! WorkerState(ip + ":" + port, workerId, executors.values.toList,
+ sender ! WorkerState(ip, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 37524a7c82..08f02bad80 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -92,7 +92,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
"Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
- " -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
+ " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index ef81f072a3..135cc2e86c 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -41,9 +41,9 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
}
} ~
path("log") {
- parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) =>
+ parameters("appId", "executorId", "logType") { (appId, executorId, logType) =>
respondWithMediaType(cc.spray.http.MediaTypes.`text/plain`) {
- getFromFileName("work/" + jobId + "/" + executorId + "/" + logType)
+ getFromFileName("work/" + appId + "/" + executorId + "/" + logType)
}
}
} ~
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index bd21ba719a..5de09030aa 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -50,14 +50,19 @@ private[spark] class Executor extends Logging {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
} catch {
- case oom: OutOfMemoryError => System.exit(ExecutorExitCode.OOM)
- case t: Throwable => System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
}
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 224c126fdd..9a82c3054c 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -68,8 +68,9 @@ private[spark] object StandaloneExecutorBackend {
}
def main(args: Array[String]) {
- if (args.length != 4) {
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>")
+ if (args.length < 4) {
+ //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index cd5b7d57f3..d1451bc212 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -198,7 +198,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
- changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
}
}
}
@@ -219,7 +219,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
def finishConnect() {
try {
channel.finishConnect
- changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
} catch {
case e: Exception => {
@@ -239,8 +239,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers ++= chunk.buffers
}
case None => {
- changeConnectionKeyInterest(0)
- /*key.interestOps(0)*/
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
return
}
}
@@ -267,6 +266,23 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
}
+
+ override def read() {
+ // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
+ try {
+ val length = channel.read(ByteBuffer.allocate(1))
+ if (length == -1) { // EOF
+ close()
+ } else if (length > 0) {
+ logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+ callOnExceptionCallback(e)
+ close()
+ }
+ }
}
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index c7f226044d..b6ec664d7e 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -66,31 +66,28 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
- val thisInstance = this
val selectorThread = new Thread("connection-manager-thread") {
- override def run() {
- thisInstance.run()
- }
+ override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
- def run() {
+ private def run() {
try {
while(!selectorThread.isInterrupted) {
- for( (connectionManagerId, sendingConnection) <- connectionRequests) {
+ for ((connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
- while(!sendMessageRequests.isEmpty) {
+ while (!sendMessageRequests.isEmpty) {
val (message, connection) = sendMessageRequests.dequeue
connection.send(message)
}
}
- while(!keyInterestChangeRequests.isEmpty) {
+ while (!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
val connection = connectionsByKey(key)
val lastOps = key.interestOps()
@@ -126,14 +123,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
- } else
- if (key.isConnectable) {
+ } else if (key.isConnectable) {
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
- } else
- if (key.isReadable) {
+ } else if (key.isReadable) {
connectionsByKey(key).read()
- } else
- if (key.isWritable) {
+ } else if (key.isWritable) {
connectionsByKey(key).write()
}
}
@@ -144,7 +138,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- def acceptConnection(key: SelectionKey) {
+ private def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
val newChannel = serverChannel.accept()
val newConnection = new ReceivingConnection(newChannel, selector)
@@ -154,7 +148,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
}
- def addConnection(connection: Connection) {
+ private def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -165,7 +159,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
connection.onClose(removeConnection)
}
- def removeConnection(connection: Connection) {
+ private def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@@ -222,16 +216,16 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
- def handleConnectionError(connection: Connection, e: Exception) {
+ private def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
removeConnection(connection)
}
- def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
}
- def receiveMessage(connection: Connection, message: Message) {
+ private def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
@@ -351,7 +345,6 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private[spark] object ConnectionManager {
def main(args: Array[String]) {
-
val manager = new ConnectionManager(9999)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
index 24b4909380..de2dce161a 100644
--- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala
+++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala
@@ -20,7 +20,7 @@ private[spark] class ApproximateActionListener[T, U, R](
extends JobListener {
val startTime = System.currentTimeMillis()
- val totalTasks = rdd.splits.size
+ val totalTasks = rdd.partitions.size
var finishedTasks = 0
var failure: Option[Exception] = None // Set if the job has failed (permanently)
var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 2c022f88e0..7348c4f15b 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
import scala.collection.mutable.HashMap
-import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
-private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
+private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
val index = idx
}
@@ -11,10 +11,6 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
- new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
- }).toArray
-
@transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
@@ -22,11 +18,14 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
HashMap(blockIds.zip(locations):_*)
}
- override def getSplits = splits_
+ override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
+ new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+ }).toArray
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
- val blockId = split.asInstanceOf[BlockRDDSplit].blockId
+ val blockId = split.asInstanceOf[BlockRDDPartition].blockId
blockManager.get(blockId) match {
case Some(block) => block.asInstanceOf[Iterator[T]]
case None =>
@@ -34,11 +33,8 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def getPreferredLocations(split: Split) =
- locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ locations_(split.asInstanceOf[BlockRDDPartition].blockId)
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 0f9ca06531..38600b8be4 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -5,22 +5,22 @@ import spark._
private[spark]
-class CartesianSplit(
+class CartesianPartition(
idx: Int,
@transient rdd1: RDD[_],
@transient rdd2: RDD[_],
s1Index: Int,
s2Index: Int
- ) extends Split {
- var s1 = rdd1.splits(s1Index)
- var s2 = rdd2.splits(s2Index)
+ ) extends Partition {
+ var s1 = rdd1.partitions(s1Index)
+ var s2 = rdd2.partitions(s2Index)
override val index: Int = idx
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- s1 = rdd1.splits(s1Index)
- s2 = rdd2.splits(s2Index)
+ s1 = rdd1.partitions(s1Index)
+ s2 = rdd2.partitions(s2Index)
oos.defaultWriteObject()
}
}
@@ -33,39 +33,40 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
- val numSplitsInRdd2 = rdd2.splits.size
+ val numPartitionsInRdd2 = rdd2.partitions.size
- override def getSplits: Array[Split] = {
+ override def getPartitions: Array[Partition] = {
// create the cross product split
- val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
- for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
- val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
+ val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size)
+ for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
+ val idx = s1.index * numPartitionsInRdd2 + s2.index
+ array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def getPreferredLocations(split: Split) = {
- val currSplit = split.asInstanceOf[CartesianSplit]
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
- override def compute(split: Split, context: TaskContext) = {
- val currSplit = split.asInstanceOf[CartesianSplit]
+ override def compute(split: Partition, context: TaskContext) = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
- def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
+ def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
},
new NarrowDependency(rdd2) {
- def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
+ def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
}
)
override def clearDependencies() {
+ super.clearDependencies()
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 96b593ba7c..9e37bdf659 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -9,7 +9,7 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
-private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
+private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
@@ -20,29 +20,27 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
- @transient val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val dirContents = fs.listStatus(new Path(checkpointPath))
- val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
- val numSplits = splitFiles.size
- if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
- !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+ val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numPartitions = partitionFiles.size
+ if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
- Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
+ Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
checkpointData.get.cpFile = Some(checkpointPath)
- override def getSplits = splits_
-
- override def getPreferredLocations(split: Split): Seq[String] = {
+ override def getPreferredLocations(split: Partition): Seq[String] = {
val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
CheckpointRDD.readFromFile(file, context)
}
@@ -109,7 +107,7 @@ private[spark] object CheckpointRDD extends Logging {
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
- // Test whether CheckpointRDD generate expected number of splits despite
+ // Test whether CheckpointRDD generate expected number of partitions despite
// each split file having multiple blocks. This needs to be run on a
// cluster (mesos or standalone) using HDFS.
def main(args: Array[String]) {
@@ -122,8 +120,8 @@ private[spark] object CheckpointRDD extends Logging {
val fs = path.getFileSystem(new Configuration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
- assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
- assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+ assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
fs.delete(path)
}
}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 4893fe8d78..5200fb6b65 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -5,7 +5,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
+import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -14,13 +14,13 @@ private[spark] sealed trait CoGroupSplitDep extends Serializable
private[spark] case class NarrowCoGroupSplitDep(
rdd: RDD[_],
splitIndex: Int,
- var split: Split
+ var split: Partition
) extends CoGroupSplitDep {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split = rdd.splits(splitIndex)
+ split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
}
}
@@ -28,7 +28,7 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
-class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable {
+class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
@@ -40,49 +40,45 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
- extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
- val aggr = new CoGroupAggregator
+ private val aggr = new CoGroupAggregator
- @transient var deps_ = {
- val deps = new ArrayBuffer[Dependency[_]]
- for ((rdd, index) <- rdds.zipWithIndex) {
+ override def getDependencies: Seq[Dependency[_]] = {
+ rdds.map { rdd =>
if (rdd.partitioner == Some(part)) {
logInfo("Adding one-to-one dependency with " + rdd)
- deps += new OneToOneDependency(rdd)
+ new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
+ new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
- deps.toList
}
- override def getDependencies = deps_
-
- @transient var splits_ : Array[Split] = {
- val array = new Array[Split](part.numPartitions)
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](part.numPartitions)
for (i <- 0 until array.size) {
- array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
+ // Each CoGroupPartition will have a dependency per contributing RDD
+ array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
+ // Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
+ new ShuffleCoGroupSplitDep(s.shuffleId)
case _ =>
- new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
}.toList)
}
array
}
- override def getSplits = splits_
-
override val partitioner = Some(part)
- override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
- val split = s.asInstanceOf[CoGroupSplit]
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
@@ -97,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
@@ -115,8 +111,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
}
override def clearDependencies() {
- deps_ = null
- splits_ = null
+ super.clearDependencies()
rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 4c57434b65..0d16cf6e85 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,19 +1,19 @@
package spark.rdd
-import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] case class CoalescedRDDSplit(
+private[spark] case class CoalescedRDDPartition(
index: Int,
@transient rdd: RDD[_],
parentsIndices: Array[Int]
- ) extends Split {
- var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
+ ) extends Partition {
+ var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- parents = parentsIndices.map(rdd.splits(_))
+ parents = parentsIndices.map(rdd.partitions(_))
oos.defaultWriteObject()
}
}
@@ -31,33 +31,34 @@ class CoalescedRDD[T: ClassManifest](
maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
- override def getSplits: Array[Split] = {
- val prevSplits = prev.splits
+ override def getPartitions: Array[Partition] = {
+ val prevSplits = prev.partitions
if (prevSplits.length < maxPartitions) {
- prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
+ new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def compute(split: Split, context: TaskContext): Iterator[T] = {
- split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ split.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
- override def getDependencies: Seq[Dependency[_]] = List(
- new NarrowDependency(prev) {
+ override def getDependencies: Seq[Dependency[_]] = {
+ Seq(new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
- }
- )
+ partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices
+ })
+ }
override def clearDependencies() {
+ super.clearDependencies()
prev = null
}
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index 6dbe235bd9..c84ec39d21 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{OneToOneDependency, RDD, Partition, TaskContext}
private[spark] class FilteredRDD[T: ClassManifest](
prev: RDD[T],
f: T => Boolean)
extends RDD[T](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).filter(f)
}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 1b604c66e2..8ebc778925 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
@@ -9,8 +9,8 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
f: T => TraversableOnce[U])
extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index 051bffed19..e16c7ba881 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
extends RDD[Array[T]](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
Array(firstParent[T].iterator(split, context).toArray).iterator
}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index f547f53812..78097502bc 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -15,14 +15,14 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
+import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
-private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split {
+private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
+ extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
@@ -42,18 +42,17 @@ class HadoopRDD[K, V](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
- extends RDD[(K, V)](sc, Nil) {
+ extends RDD[(K, V)](sc, Nil) with Logging {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
- val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
- @transient
- val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val inputFormat = createInputFormat(conf)
val inputSplits = inputFormat.getSplits(conf, minSplits)
- val array = new Array[Split](inputSplits.size)
+ val array = new Array[Partition](inputSplits.size)
for (i <- 0 until inputSplits.size) {
- array(i) = new HadoopSplit(id, i, inputSplits(i))
+ array(i) = new HadoopPartition(id, i, inputSplits(i))
}
array
}
@@ -63,10 +62,8 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def getSplits = splits_
-
- override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopSplit]
+ override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
var reader: RecordReader[K, V] = null
val conf = confBroadcast.value.value
@@ -74,7 +71,7 @@ class HadoopRDD[K, V](
reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => reader.close())
+ context.addOnCompleteCallback{ () => close() }
val key: K = reader.createKey()
val value: V = reader.createValue()
@@ -91,9 +88,6 @@ class HadoopRDD[K, V](
}
gotNext = true
}
- if (finished) {
- reader.close()
- }
!finished
}
@@ -107,11 +101,19 @@ class HadoopRDD[K, V](
gotNext = false
(key, value)
}
+
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
+ }
}
- override def getPreferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Partition): Seq[String] = {
// TODO: Filtering out "localhost" in case of file:// URLs
- val hadoopSplit = split.asInstanceOf[HadoopSplit]
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index 073f7d7d2a..d283c5b2bb 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
@@ -13,8 +13,8 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
override val partitioner =
if (preservesPartitioning) firstParent[T].partitioner else None
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
f(firstParent[T].iterator(split, context))
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala
index 2ddc3d01b6..afb7504ba1 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala
@@ -1,24 +1,24 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
/**
- * A variant of the MapPartitionsRDD that passes the split index into the
+ * A variant of the MapPartitionsRDD that passes the partition index into the
* closure. This can be used to generate or collect partition specific
* information such as the number of tuples in a partition.
*/
private[spark]
-class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
f(split.index, firstParent[T].iterator(split, context))
-} \ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 5466c9c657..af07311b6d 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,13 +1,13 @@
package spark.rdd
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext) =
+ override def compute(split: Partition, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
}
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index c3b155fcbd..df2361025c 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -7,12 +7,12 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext}
+import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
private[spark]
-class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
- extends Split {
+class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+ extends Partition {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
@@ -26,10 +26,11 @@ class NewHadoopRDD[K, V](
valueClass: Class[V],
@transient conf: Configuration)
extends RDD[(K, V)](sc, Nil)
- with HadoopMapReduceUtil {
+ with HadoopMapReduceUtil
+ with Logging {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
- val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
// private val serializableConf = new SerializableWritable(conf)
private val jobtrackerId: String = {
@@ -39,21 +40,19 @@ class NewHadoopRDD[K, V](
@transient private val jobId = new JobID(jobtrackerId, id)
- @transient private val splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
- val result = new Array[Split](rawSplits.size)
+ val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
- result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
}
result
}
- override def getSplits = splits_
-
- override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopSplit]
+ override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
val conf = confBroadcast.value.value
val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
@@ -63,7 +62,7 @@ class NewHadoopRDD[K, V](
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => reader.close())
+ context.addOnCompleteCallback(() => close())
var havePair = false
var finished = false
@@ -83,10 +82,18 @@ class NewHadoopRDD[K, V](
havePair = false
return (reader.getCurrentKey, reader.getCurrentValue)
}
+
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
+ }
}
- override def getPreferredLocations(split: Split) = {
- val theSplit = split.asInstanceOf[NewHadoopSplit]
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val theSplit = split.asInstanceOf[NewHadoopPartition]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
}
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 10adcd53ec..07585a88ce 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -1,28 +1,29 @@
-package spark
+package spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
+import spark.{RDD, TaskContext, SparkContext, Partition}
-private[spark] class ParallelCollectionSplit[T: ClassManifest](
+private[spark] class ParallelCollectionPartition[T: ClassManifest](
val rddId: Long,
val slice: Int,
values: Seq[T])
- extends Split with Serializable {
+ extends Partition with Serializable {
def iterator: Iterator[T] = values.iterator
override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
override def equals(other: Any): Boolean = other match {
- case that: ParallelCollectionSplit[_] => (this.rddId == that.rddId && this.slice == that.slice)
+ case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice)
case _ => false
}
override val index: Int = slice
}
-private[spark] class ParallelCollection[T: ClassManifest](
+private[spark] class ParallelCollectionRDD[T: ClassManifest](
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
@@ -33,26 +34,20 @@ private[spark] class ParallelCollection[T: ClassManifest](
// instead.
// UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient var splits_ : Array[Split] = {
- val slices = ParallelCollection.slice(data, numSlices).toArray
- slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
+ override def getPartitions: Array[Partition] = {
+ val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
+ slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def getSplits = splits_
+ override def compute(s: Partition, context: TaskContext) =
+ s.asInstanceOf[ParallelCollectionPartition[T]].iterator
- override def compute(s: Split, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionSplit[T]].iterator
-
- override def getPreferredLocations(s: Split): Seq[String] = {
+ override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
-private object ParallelCollection {
+private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
index a50ce75171..41ff62dd22 100644
--- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext}
-class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
+class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition {
override val index = idx
}
@@ -16,15 +16,15 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
extends NarrowDependency[T](rdd) {
@transient
- val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
- .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
+ val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index))
+ .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
- * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
+ * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
@@ -34,9 +34,21 @@ class PartitionPruningRDD[T: ClassManifest](
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
- override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
- split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
+ override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context)
- override protected def getSplits =
+ override protected def getPartitions: Array[Partition] =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
}
+
+
+object PartitionPruningRDD {
+
+ /**
+ * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD
+ * when its type T is not known at compile time.
+ */
+ def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = {
+ new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest)
+ }
+}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 6631f83510..962a1b21ad 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.{RDD, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkEnv, Partition, TaskContext}
/**
@@ -27,9 +27,9 @@ class PipedRDD[T: ClassManifest](
// using a standard StringTokenizer (i.e. by spaces)
def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def getSplits = firstParent[T].splits
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
- override def compute(split: Split, context: TaskContext): Iterator[String] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index e24ad23b21..243673f151 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -5,10 +5,10 @@ import java.util.Random
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.{RDD, Split, TaskContext}
+import spark.{RDD, Partition, TaskContext}
private[spark]
-class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
+class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
override val index: Int = prev.index
}
@@ -19,18 +19,16 @@ class SampledRDD[T: ClassManifest](
seed: Int)
extends RDD[T](prev) {
- @transient var splits_ : Array[Split] = {
+ override def getPartitions: Array[Partition] = {
val rg = new Random(seed)
- firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
+ firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt))
}
- override def getSplits = splits_
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev)
- override def getPreferredLocations(split: Split) =
- firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
-
- override def compute(splitIn: Split, context: TaskContext) = {
- val split = splitIn.asInstanceOf[SampledRDDSplit]
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[SampledRDDPartition]
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
@@ -48,8 +46,4 @@ class SampledRDD[T: ClassManifest](
firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
-
- override def clearDependencies() {
- splits_ = null
- }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index d396478673..c2f118305f 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,9 +1,9 @@
package spark.rdd
-import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
import spark.SparkContext._
-private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
+private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
override def hashCode(): Int = idx
}
@@ -22,9 +22,11 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
- override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
+ override def getPartitions: Array[Partition] = {
+ Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
+ }
- override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
new file mode 100644
index 0000000000..daf9cc993c
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -0,0 +1,108 @@
+package spark.rdd
+
+import java.util.{HashSet => JHashSet}
+import scala.collection.JavaConversions._
+import spark.RDD
+import spark.Partitioner
+import spark.Dependency
+import spark.TaskContext
+import spark.Partition
+import spark.SparkEnv
+import spark.ShuffleDependency
+import spark.OneToOneDependency
+
+/**
+ * An optimized version of cogroup for set difference/subtraction.
+ *
+ * It is possible to implement this operation with just `cogroup`, but
+ * that is less efficient because all of the entries from `rdd2`, for
+ * both matching and non-matching values in `rdd1`, are kept in the
+ * JHashMap until the end.
+ *
+ * With this implementation, only the entries from `rdd1` are kept in-memory,
+ * and the entries from `rdd2` are essentially streamed, as we only need to
+ * touch each once to decide if the value needs to be removed.
+ *
+ * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as
+ * you can use `rdd1`'s partitioner/partition size and not worry about running
+ * out of memory because of the size of `rdd2`.
+ */
+private[spark] class SubtractedRDD[T: ClassManifest](
+ @transient var rdd1: RDD[T],
+ @transient var rdd2: RDD[T],
+ part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ Seq(rdd1, rdd2).map { rdd =>
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ new OneToOneDependency(rdd)
+ } else {
+ logInfo("Adding shuffle dependency with " + rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(i => {
+ val set = new JHashSet[T]()
+ while (i.hasNext) {
+ set.add(i.next)
+ }
+ set.iterator
+ }, true)
+ // ShuffleDependency requires a tuple (k, v), which it will partition by k.
+ // We need this to partition to map to the same place as the k for
+ // OneToOneDependency, which means:
+ // - for already-tupled RDD[(A, B)], into getPartition(a)
+ // - for non-tupled RDD[C], into getPartition(c)
+ val part2 = new Partitioner() {
+ def numPartitions = part.numPartitions
+ def getPartition(key: Any) = key match {
+ case (k, v) => part.getPartition(k)
+ case k => part.getPartition(k)
+ }
+ }
+ new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
+ }
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](part.numPartitions)
+ for (i <- 0 until array.size) {
+ // Each CoGroupPartition will depend on rdd1 and rdd2
+ array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
+ dependencies(j) match {
+ case s: ShuffleDependency[_, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleId)
+ case _ =>
+ new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+ }
+ }.toList)
+ }
+ array
+ }
+
+ override val partitioner = Some(part)
+
+ override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+ val partition = p.asInstanceOf[CoGroupPartition]
+ val set = new JHashSet[T]
+ def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+ for (k <- rdd.iterator(itsSplit, context))
+ op(k.asInstanceOf[T])
+ case ShuffleCoGroupSplitDep(shuffleId) =>
+ for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index))
+ op(k.asInstanceOf[T])
+ }
+ // the first dep is rdd1; add all keys to the set
+ integrate(partition.deps(0), set.add)
+ // the second dep is rdd2; remove all of its keys from the set
+ integrate(partition.deps(1), set.remove)
+ set.iterator
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+
+} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index 26a2d511f2..2c52a67e22 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,13 +1,13 @@
package spark.rdd
import scala.collection.mutable.ArrayBuffer
-import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
- extends Split {
+private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Partition {
- var split: Split = rdd.splits(splitIndex)
+ var split: Partition = rdd.partitions(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
@@ -18,7 +18,7 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split = rdd.splits(splitIndex)
+ split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
}
}
@@ -28,11 +28,11 @@ class UnionRDD[T: ClassManifest](
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
- override def getSplits: Array[Split] = {
- val array = new Array[Split](rdds.map(_.splits.size).sum)
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](rdds.map(_.partitions.size).sum)
var pos = 0
- for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split.index)
+ for (rdd <- rdds; split <- rdd.partitions) {
+ array(pos) = new UnionPartition(pos, rdd, split.index)
pos += 1
}
array
@@ -42,19 +42,15 @@ class UnionRDD[T: ClassManifest](
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
- deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
- pos += rdd.splits.size
+ deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size)
+ pos += rdd.partitions.size
}
deps
}
- override def compute(s: Split, context: TaskContext): Iterator[T] =
- s.asInstanceOf[UnionSplit[T]].iterator(context)
+ override def compute(s: Partition, context: TaskContext): Iterator[T] =
+ s.asInstanceOf[UnionPartition[T]].iterator(context)
- override def getPreferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[UnionSplit[T]].preferredLocations()
-
- override def clearDependencies() {
- rdds = null
- }
+ override def getPreferredLocations(s: Partition): Seq[String] =
+ s.asInstanceOf[UnionPartition[T]].preferredLocations()
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index e5df6d8c72..e80ec17aa5 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,17 +1,17 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
-private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
idx: Int,
@transient rdd1: RDD[T],
@transient rdd2: RDD[U]
- ) extends Split {
+ ) extends Partition {
- var split1 = rdd1.splits(idx)
- var split2 = rdd1.splits(idx)
+ var split1 = rdd1.partitions(idx)
+ var split2 = rdd1.partitions(idx)
override val index: Int = idx
def splits = (split1, split2)
@@ -19,8 +19,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
// Update the reference to parent split at the time of task serialization
- split1 = rdd1.splits(idx)
- split2 = rdd2.splits(idx)
+ split1 = rdd1.partitions(idx)
+ split2 = rdd2.partitions(idx)
oos.defaultWriteObject()
}
}
@@ -29,31 +29,31 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
- extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
- with Serializable {
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
- override def getSplits: Array[Split] = {
- if (rdd1.splits.size != rdd2.splits.size) {
+ override def getPartitions: Array[Partition] = {
+ if (rdd1.partitions.size != rdd2.partitions.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
- val array = new Array[Split](rdd1.splits.size)
- for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2)
+ val array = new Array[Partition](rdd1.partitions.size)
+ for (i <- 0 until rdd1.partitions.size) {
+ array(i) = new ZippedPartition(i, rdd1, rdd2)
}
array
}
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
- val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
}
- override def getPreferredLocations(s: Split): Seq[String] = {
- val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits
rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
}
override def clearDependencies() {
+ super.clearDependencies()
rdd1 = null
rdd2 = null
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 319eef6978..bf0837c066 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -106,7 +106,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
}.toArray
@@ -141,9 +141,9 @@ class DAGScheduler(
private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of splits is unknown
+ // since we can't do it in the RDD constructor because # of partitions is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
@@ -162,7 +162,7 @@ class DAGScheduler(
if (!visited(r)) {
visited += r
// Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of splits is unknown
+ // we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -257,7 +257,7 @@ class DAGScheduler(
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val partitions = (0 until rdd.splits.size).toArray
+ val partitions = (0 until rdd.partitions.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
return listener.awaitResult() // Will throw an exception if the job fails
}
@@ -386,7 +386,7 @@ class DAGScheduler(
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
- val split = rdd.splits(job.partitions(0))
+ val split = rdd.partitions(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -672,7 +672,7 @@ class DAGScheduler(
return cached
}
// If the RDD has some placement preferences (as is the case for input RDDs), get those
- val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList
+ val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (rddPrefs != Nil) {
return rddPrefs
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 8cd4c661eb..1721f78f48 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -67,7 +67,7 @@ private[spark] class ResultTask[T, U](
var split = if (rdd == null) {
null
} else {
- rdd.splits(partition)
+ rdd.partitions(partition)
}
override def run(attemptId: Long): U = {
@@ -85,7 +85,7 @@ private[spark] class ResultTask[T, U](
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.splits(partition)
+ split = rdd.partitions(partition)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
@@ -107,6 +107,6 @@ private[spark] class ResultTask[T, U](
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partition = in.readInt()
val outputId = in.readInt()
- split = in.readObject().asInstanceOf[Split]
+ split = in.readObject().asInstanceOf[Partition]
}
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index bed9f1864f..59ee3c0a09 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -86,12 +86,12 @@ private[spark] class ShuffleMapTask(
var split = if (rdd == null) {
null
} else {
- rdd.splits(partition)
+ rdd.partitions(partition)
}
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.splits(partition)
+ split = rdd.partitions(partition)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
@@ -112,7 +112,7 @@ private[spark] class ShuffleMapTask(
dep = dep_
partition = in.readInt()
generation = in.readLong()
- split = in.readObject().asInstanceOf[Split]
+ split = in.readObject().asInstanceOf[Partition]
}
override def run(attemptId: Long): MapStatus = {
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 374114d870..552061e46b 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -28,7 +28,7 @@ private[spark] class Stage(
extends Logging {
val isShuffleMap = shuffleDep != None
- val numPartitions = rdd.splits.size
+ val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 1e4fbdb874..d9c2f9517b 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -11,6 +11,7 @@ import spark.TaskState.TaskState
import spark.scheduler._
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
/**
* The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
@@ -22,6 +23,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+ // Threshold above which we warn user initial TaskSet may be starved
+ val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
val activeTaskSets = new HashMap[String, TaskSetManager]
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
@@ -30,6 +33,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+ var hasReceivedTask = false
+ var hasLaunchedTask = false
+ val starvationTimer = new Timer(true)
+
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
@@ -94,6 +101,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
activeTaskSets(taskSet.id) = manager
activeTaskSetsQueue += manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+
+ if (hasReceivedTask == false) {
+ starvationTimer.scheduleAtFixedRate(new TimerTask() {
+ override def run() {
+ if (!hasLaunchedTask) {
+ logWarning("Initial job has not accepted any resources; " +
+ "check your cluster UI to ensure that workers are registered")
+ } else {
+ this.cancel()
+ }
+ }
+ }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
+ }
+ hasReceivedTask = true;
}
backend.reviveOffers()
}
@@ -150,6 +171,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
} while (launchedTask)
}
+ if (tasks.size > 0) {
+ hasLaunchedTask = true
+ }
return tasks
}
}
@@ -235,7 +259,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
override def defaultParallelism() = backend.defaultParallelism()
-
+
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
index bba7de6a65..8bf838209f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala
@@ -12,10 +12,10 @@ class ExecutorLossReason(val message: String) {
private[spark]
case class ExecutorExited(val exitCode: Int)
- extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
+ extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
}
private[spark]
case class SlaveLost(_message: String = "Slave lost")
- extends ExecutorLossReason(_message) {
+ extends ExecutorLossReason(_message) {
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 59ff8bcb90..bb289c9cf3 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -2,14 +2,14 @@ package spark.scheduler.cluster
import spark.{Utils, Logging, SparkContext}
import spark.deploy.client.{Client, ClientListener}
-import spark.deploy.{Command, JobDescription}
+import spark.deploy.{Command, ApplicationDescription}
import scala.collection.mutable.HashMap
private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- jobName: String)
+ appName: String)
extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -29,10 +29,11 @@ private[spark] class SparkDeploySchedulerBackend(
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
+ val sparkHome = sc.getSparkHome().getOrElse(
+ throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome)
- client = new Client(sc.env.actorSystem, master, jobDesc, this)
+ client = new Client(sc.env.actorSystem, master, appDesc, this)
client.start()
}
@@ -45,8 +46,8 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
- override def connected(jobId: String) {
- logInfo("Connected to Spark cluster with job ID " + jobId)
+ override def connected(appId: String) {
+ logInfo("Connected to Spark cluster with app ID " + appId)
}
override def disconnected() {
@@ -67,6 +68,6 @@ private[spark] class SparkDeploySchedulerBackend(
case None => SlaveLost(message)
}
logInfo("Executor %s removed: %s".format(executorId, message))
- scheduler.executorLost(executorId, reason)
+ removeExecutor(executorId, reason.toString)
}
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
index da7dcf4b6b..d766067824 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -37,3 +37,6 @@ object StatusUpdate {
// Internal messages in driver
private[spark] case object ReviveOffers extends StandaloneClusterMessage
private[spark] case object StopDriver extends StandaloneClusterMessage
+
+private[spark] case class RemoveExecutor(executorId: String, reason: String)
+ extends StandaloneClusterMessage
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 082022be1c..7a428e3361 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -68,6 +68,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! true
context.stop(self)
+ case RemoveExecutor(executorId, reason) =>
+ removeExecutor(executorId, reason)
+ sender ! true
+
case Terminated(actor) =>
actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
@@ -100,16 +104,18 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Remove a disconnected slave from the cluster
def removeExecutor(executorId: String, reason: String) {
- logInfo("Slave " + executorId + " disconnected, so removing it")
- val numCores = freeCores(executorId)
- actorToExecutorId -= executorActor(executorId)
- addressToExecutorId -= executorAddress(executorId)
- executorActor -= executorId
- executorHost -= executorId
- freeCores -= executorId
- executorHost -= executorId
- totalCoreCount.addAndGet(-numCores)
- scheduler.executorLost(executorId, SlaveLost(reason))
+ if (executorActor.contains(executorId)) {
+ logInfo("Executor " + executorId + " disconnected, so removing it")
+ val numCores = freeCores(executorId)
+ actorToExecutorId -= executorActor(executorId)
+ addressToExecutorId -= executorAddress(executorId)
+ executorActor -= executorId
+ executorHost -= executorId
+ freeCores -= executorId
+ executorHost -= executorId
+ totalCoreCount.addAndGet(-numCores)
+ scheduler.executorLost(executorId, SlaveLost(reason))
+ }
}
}
@@ -139,7 +145,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
} catch {
case e: Exception =>
- throw new SparkException("Error stopping standalone scheduler's master actor", e)
+ throw new SparkException("Error stopping standalone scheduler's driver actor", e)
}
}
@@ -147,7 +153,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
driverActor ! ReviveOffers
}
- override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2)
+ override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
+ .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
+
+ // Called by subclasses when notified of a lost worker
+ def removeExecutor(executorId: String, reason: String) {
+ try {
+ val timeout = 5.seconds
+ val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error notifying standalone scheduler's driver actor", e)
+ }
+ }
}
private[spark] object StandaloneSchedulerBackend {
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index b481ec0a72..f4a2994b6d 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -28,7 +28,7 @@ private[spark] class CoarseMesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- frameworkName: String)
+ appName: String)
extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -76,7 +76,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -239,7 +239,11 @@ private[spark] class CoarseMesosSchedulerBackend(
override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
logInfo("Mesos slave lost: " + slaveId.getValue)
synchronized {
- slaveIdsWithExecutors -= slaveId.getValue
+ if (slaveIdsWithExecutors.contains(slaveId.getValue)) {
+ // Note that the slave ID corresponds to the executor ID on that slave
+ slaveIdsWithExecutors -= slaveId.getValue
+ removeExecutor(slaveId.getValue, "Mesos slave lost")
+ }
}
}
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 300766d0f5..ca7fab4cc5 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -24,7 +24,7 @@ private[spark] class MesosSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
master: String,
- frameworkName: String)
+ appName: String)
extends SchedulerBackend
with MScheduler
with Logging {
@@ -49,7 +49,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index d9838f65ab..266191b05f 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -513,7 +513,7 @@ class BlockManager(
}
}
- // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
@@ -585,7 +585,7 @@ class BlockManager(
resultsGotten += 1
val result = results.take()
bytesInFlight -= result.size
- if (!fetchRequests.isEmpty &&
+ while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index 7e5b820cbb..ddbf8821ad 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -178,7 +178,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
- localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ try {
+ localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
+ } catch {
+ case t: Throwable => logError("Exception while deleting local spark dirs", t)
+ }
}
})
}
diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala
index 5f72b67b2c..dec47a9d41 100644
--- a/core/src/main/scala/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/spark/storage/StorageUtils.scala
@@ -63,7 +63,7 @@ object StorageUtils {
val rddName = Option(rdd.name).getOrElse(rddKey)
val rddStorageLevel = rdd.getStorageLevel
- RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize)
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize)
}.toArray
}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
index a342d378ff..dafa906712 100644
--- a/core/src/main/scala/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -38,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
object MetadataCleaner {
- def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt
- def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) }
+ def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
+ def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) }
}
diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala
index 03559751bc..835822edb2 100644
--- a/core/src/main/scala/spark/util/Vector.scala
+++ b/core/src/main/scala/spark/util/Vector.scala
@@ -11,12 +11,16 @@ class Vector(val elements: Array[Double]) extends Serializable {
return Vector(length, i => this(i) + other(i))
}
+ def add(other: Vector) = this + other
+
def - (other: Vector): Vector = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
return Vector(length, i => this(i) - other(i))
}
+ def subtract(other: Vector) = this - other
+
def dot(other: Vector): Double = {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
@@ -61,10 +65,16 @@ class Vector(val elements: Array[Double]) extends Serializable {
this
}
+ def addInPlace(other: Vector) = this +=other
+
def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
+ def multiply (d: Double) = this * d
+
def / (d: Double): Vector = this * (1 / d)
+ def divide (d: Double) = this / d
+
def unary_- = this * -1
def sum = elements.reduceLeft(_ + _)
diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
new file mode 100644
index 0000000000..301a7e2124
--- /dev/null
+++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html
@@ -0,0 +1,40 @@
+@(app: spark.deploy.master.ApplicationInfo)
+
+@spark.common.html.layout(title = "Application Details") {
+
+ <!-- Application Details -->
+ <div class="row">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>ID:</strong> @app.id</li>
+ <li><strong>Description:</strong> @app.desc.name</li>
+ <li><strong>User:</strong> @app.desc.user</li>
+ <li><strong>Cores:</strong>
+ @app.desc.cores
+ (@app.coresGranted Granted
+ @if(app.desc.cores == Integer.MAX_VALUE) {
+
+ } else {
+ , @app.coresLeft
+ }
+ )
+ </li>
+ <li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li>
+ <li><strong>Submit Date:</strong> @app.submitDate</li>
+ <li><strong>State:</strong> @app.state</li>
+ </ul>
+ </div>
+ </div>
+
+ <hr/>
+
+ <!-- Executors -->
+ <div class="row">
+ <div class="span12">
+ <h3> Executor Summary </h3>
+ <br/>
+ @executors_table(app.executors.values.toList)
+ </div>
+ </div>
+
+}
diff --git a/core/src/main/twirl/spark/deploy/master/app_row.scala.html b/core/src/main/twirl/spark/deploy/master/app_row.scala.html
new file mode 100644
index 0000000000..feb306f35c
--- /dev/null
+++ b/core/src/main/twirl/spark/deploy/master/app_row.scala.html
@@ -0,0 +1,20 @@
+@(app: spark.deploy.master.ApplicationInfo)
+
+@import spark.Utils
+@import spark.deploy.WebUI.formatDate
+@import spark.deploy.WebUI.formatDuration
+
+<tr>
+ <td>
+ <a href="app?appId=@(app.id)">@app.id</a>
+ </td>
+ <td>@app.desc.name</td>
+ <td>
+ @app.coresGranted
+ </td>
+ <td>@Utils.memoryMegabytesToString(app.desc.memoryPerSlave)</td>
+ <td>@formatDate(app.submitDate)</td>
+ <td>@app.desc.user</td>
+ <td>@app.state.toString()</td>
+ <td>@formatDuration(app.duration)</td>
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/master/job_table.scala.html b/core/src/main/twirl/spark/deploy/master/app_table.scala.html
index d267d6e85e..f789cee0f1 100644
--- a/core/src/main/twirl/spark/deploy/master/job_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/app_table.scala.html
@@ -1,9 +1,9 @@
-@(jobs: Array[spark.deploy.master.JobInfo])
+@(apps: Array[spark.deploy.master.ApplicationInfo])
<table class="table table-bordered table-striped table-condensed sortable">
<thead>
<tr>
- <th>JobID</th>
+ <th>ID</th>
<th>Description</th>
<th>Cores</th>
<th>Memory per Node</th>
@@ -14,8 +14,8 @@
</tr>
</thead>
<tbody>
- @for(j <- jobs) {
- @job_row(j)
+ @for(j <- apps) {
+ @app_row(j)
}
</tbody>
</table>
diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
index 784d692fc2..d2d80fad48 100644
--- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html
@@ -9,7 +9,7 @@
<td>@executor.memory</td>
<td>@executor.state</td>
<td>
- <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stdout">stdout</a>
- <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stderr">stderr</a>
+ <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stdout">stdout</a>
+ <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stderr">stderr</a>
</td>
-</tr> \ No newline at end of file
+</tr>
diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html
index 285645c389..ac51a39a51 100644
--- a/core/src/main/twirl/spark/deploy/master/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/index.scala.html
@@ -2,19 +2,19 @@
@import spark.deploy.master._
@import spark.Utils
-@spark.common.html.layout(title = "Spark Master on " + state.uri) {
-
+@spark.common.html.layout(title = "Spark Master on " + state.host) {
+
<!-- Cluster Details -->
<div class="row">
<div class="span12">
<ul class="unstyled">
- <li><strong>URL:</strong> spark://@(state.uri)</li>
+ <li><strong>URL:</strong> @(state.uri)</li>
<li><strong>Workers:</strong> @state.workers.size </li>
<li><strong>Cores:</strong> @{state.workers.map(_.cores).sum} Total,
@{state.workers.map(_.coresUsed).sum} Used</li>
<li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(state.workers.map(_.memory).sum)} Total,
@{Utils.memoryMegabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
- <li><strong>Jobs:</strong> @state.activeJobs.size Running, @state.completedJobs.size Completed </li>
+ <li><strong>Applications:</strong> @state.activeApps.size Running, @state.completedApps.size Completed </li>
</ul>
</div>
</div>
@@ -22,7 +22,7 @@
<!-- Worker Summary -->
<div class="row">
<div class="span12">
- <h3> Cluster Summary </h3>
+ <h3> Workers </h3>
<br/>
@worker_table(state.workers.sortBy(_.id))
</div>
@@ -30,23 +30,23 @@
<hr/>
- <!-- Job Summary (Running) -->
+ <!-- App Summary (Running) -->
<div class="row">
<div class="span12">
- <h3> Running Jobs </h3>
+ <h3> Running Applications </h3>
<br/>
- @job_table(state.activeJobs.sortBy(_.startTime).reverse)
+ @app_table(state.activeApps.sortBy(_.startTime).reverse)
</div>
</div>
<hr/>
- <!-- Job Summary (Completed) -->
+ <!-- App Summary (Completed) -->
<div class="row">
<div class="span12">
- <h3> Completed Jobs </h3>
+ <h3> Completed Applications </h3>
<br/>
- @job_table(state.completedJobs.sortBy(_.endTime).reverse)
+ @app_table(state.completedApps.sortBy(_.endTime).reverse)
</div>
</div>
diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html
deleted file mode 100644
index d02a51b214..0000000000
--- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html
+++ /dev/null
@@ -1,40 +0,0 @@
-@(job: spark.deploy.master.JobInfo)
-
-@spark.common.html.layout(title = "Job Details") {
-
- <!-- Job Details -->
- <div class="row">
- <div class="span12">
- <ul class="unstyled">
- <li><strong>ID:</strong> @job.id</li>
- <li><strong>Description:</strong> @job.desc.name</li>
- <li><strong>User:</strong> @job.desc.user</li>
- <li><strong>Cores:</strong>
- @job.desc.cores
- (@job.coresGranted Granted
- @if(job.desc.cores == Integer.MAX_VALUE) {
-
- } else {
- , @job.coresLeft
- }
- )
- </li>
- <li><strong>Memory per Slave:</strong> @job.desc.memoryPerSlave</li>
- <li><strong>Submit Date:</strong> @job.submitDate</li>
- <li><strong>State:</strong> @job.state</li>
- </ul>
- </div>
- </div>
-
- <hr/>
-
- <!-- Executors -->
- <div class="row">
- <div class="span12">
- <h3> Executor Summary </h3>
- <br/>
- @executors_table(job.executors.values.toList)
- </div>
- </div>
-
-}
diff --git a/core/src/main/twirl/spark/deploy/master/job_row.scala.html b/core/src/main/twirl/spark/deploy/master/job_row.scala.html
deleted file mode 100644
index 7c466a6a2c..0000000000
--- a/core/src/main/twirl/spark/deploy/master/job_row.scala.html
+++ /dev/null
@@ -1,20 +0,0 @@
-@(job: spark.deploy.master.JobInfo)
-
-@import spark.Utils
-@import spark.deploy.WebUI.formatDate
-@import spark.deploy.WebUI.formatDuration
-
-<tr>
- <td>
- <a href="job?jobId=@(job.id)">@job.id</a>
- </td>
- <td>@job.desc.name</td>
- <td>
- @job.coresGranted
- </td>
- <td>@Utils.memoryMegabytesToString(job.desc.memoryPerSlave)</td>
- <td>@formatDate(job.submitDate)</td>
- <td>@job.desc.user</td>
- <td>@job.state.toString()</td>
- <td>@formatDuration(job.duration)</td>
-</tr>
diff --git a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
index ea9542461e..dad0a89080 100644
--- a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html
@@ -8,13 +8,13 @@
<td>@Utils.memoryMegabytesToString(executor.memory)</td>
<td>
<ul class="unstyled">
- <li><strong>ID:</strong> @executor.jobId</li>
- <li><strong>Name:</strong> @executor.jobDesc.name</li>
- <li><strong>User:</strong> @executor.jobDesc.user</li>
+ <li><strong>ID:</strong> @executor.appId</li>
+ <li><strong>Name:</strong> @executor.appDesc.name</li>
+ <li><strong>User:</strong> @executor.appDesc.user</li>
</ul>
</td>
<td>
- <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stdout">stdout</a>
- <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stderr">stderr</a>
+ <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stdout">stdout</a>
+ <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stderr">stderr</a>
</td>
</tr>
diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html
index 1d703dae58..c39f769a73 100644
--- a/core/src/main/twirl/spark/deploy/worker/index.scala.html
+++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html
@@ -1,8 +1,8 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
-@spark.common.html.layout(title = "Spark Worker on " + worker.uri) {
-
+@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
+
<!-- Worker Details -->
<div class="row">
<div class="span12">
@@ -10,12 +10,12 @@
<li><strong>ID:</strong> @worker.workerId</li>
<li><strong>
Master URL:</strong> @worker.masterUrl
- (WebUI at <a href="@worker.masterWebUiUrl">@worker.masterWebUiUrl</a>)
</li>
<li><strong>Cores:</strong> @worker.cores (@worker.coresUsed Used)</li>
<li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(worker.memory)}
(@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li>
</ul>
+ <p><a href="@worker.masterWebUiUrl">Back to Master</a></p>
</div>
</div>
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 0b74607fb8..ca385972fb 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -34,7 +34,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
(i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
@@ -43,14 +43,14 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ParallelCollection") {
val parCollection = sc.makeRDD(1 to 4, 2)
- val numSplits = parCollection.splits.size
+ val numPartitions = parCollection.partitions.size
parCollection.checkpoint()
assert(parCollection.dependencies === Nil)
val result = parCollection.collect()
assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
assert(parCollection.dependencies != Nil)
- assert(parCollection.splits.length === numSplits)
- assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
+ assert(parCollection.partitions.length === numPartitions)
+ assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList)
assert(parCollection.collect() === result)
}
@@ -59,13 +59,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
- val numSplits = blockRDD.splits.size
+ val numPartitions = blockRDD.partitions.size
blockRDD.checkpoint()
val result = blockRDD.collect()
assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
assert(blockRDD.dependencies != Nil)
- assert(blockRDD.splits.length === numSplits)
- assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
+ assert(blockRDD.partitions.length === numPartitions)
+ assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList)
assert(blockRDD.collect() === result)
}
@@ -79,9 +79,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("UnionRDD") {
def otherRDD = sc.makeRDD(1 to 10, 1)
- // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed.
+ // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed.
// Current implementation of UnionRDD has transient reference to parent RDDs,
- // so only the splits will reduce in serialized size, not the RDD.
+ // so only the partitions will reduce in serialized size, not the RDD.
testCheckpointing(_.union(otherRDD), false, true)
testParentCheckpointing(_.union(otherRDD), false, true)
}
@@ -91,21 +91,21 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(new CartesianRDD(sc, _, otherRDD))
// Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the splits.
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
- // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
- // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
- serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
cartesian.count() // do the checkpointing
val splitAfterCheckpoint =
- serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
assert(
(splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
(splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
@@ -114,27 +114,27 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("CoalescedRDD") {
- testCheckpointing(new CoalescedRDD(_, 2))
+ testCheckpointing(_.coalesce(2))
// Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the splits.
- testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+ // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the partitions.
+ testParentCheckpointing(_.coalesce(2), true, false)
- // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
- // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
- // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+ // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
+ // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
- serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
coalesced.count() // do the checkpointing
val splitAfterCheckpoint =
- serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
assert(
splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
- "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
)
}
@@ -156,30 +156,40 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
// Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
- // so only the RDD will reduce in serialized size, not the splits.
+ // Current implementation of ZippedRDDPartitions has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the partitions.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
}
+ test("CheckpointRDD with zero partitions") {
+ val rdd = new BlockRDD[Int](sc, Array[String]())
+ assert(rdd.partitions.size === 0)
+ assert(rdd.isCheckpointed === false)
+ rdd.checkpoint()
+ assert(rdd.count() === 0)
+ assert(rdd.isCheckpointed === true)
+ assert(rdd.partitions.size === 0)
+ }
+
/**
* Test checkpointing of the final RDD generated by the given operation. By default,
* this method tests whether the size of serialized RDD has reduced after checkpointing or not.
- * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
- * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or
+ * not, but this is not done by default as usually the partitions do not refer to any RDD and
* therefore never store the lineage.
*/
def testCheckpointing[U: ClassManifest](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean = true,
- testRDDSplitSize: Boolean = false
+ testRDDPartitionSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
- val numSplits = operatedRDD.splits.length
+ val numPartitions = operatedRDD.partitions.length
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
@@ -193,11 +203,11 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// Test whether dependencies have been changed from its earlier parent RDD
assert(operatedRDD.dependencies.head.rdd != parentRDD)
- // Test whether the splits have been changed to the new Hadoop splits
- assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+ // Test whether the partitions have been changed to the new Hadoop partitions
+ assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
- // Test whether the number of splits is same as before
- assert(operatedRDD.splits.length === numSplits)
+ // Test whether the number of partitions is same as before
+ assert(operatedRDD.partitions.length === numPartitions)
// Test whether the data in the checkpointed RDD is same as original
assert(operatedRDD.collect() === result)
@@ -215,18 +225,18 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
)
}
- // Test whether serialized size of the splits has reduced. If the splits
- // do not have any non-transient reference to another RDD or another RDD's splits, it
+ // Test whether serialized size of the partitions has reduced. If the partitions
+ // do not have any non-transient reference to another RDD or another RDD's partitions, it
// does not refer to a lineage and therefore may not reduce in size after checkpointing.
- // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
// must be forgotten after checkpointing (to remove all reference to parent RDDs) and
- // replaced with the HadoopSplits of the checkpointed RDD.
- if (testRDDSplitSize) {
- logInfo("Size of " + rddType + " splits "
+ // replaced with the HadooPartitions of the checkpointed RDD.
+ if (testRDDPartitionSize) {
+ logInfo("Size of " + rddType + " partitions "
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
assert(
splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "Size of " + rddType + " partitions did not reduce after checkpointing " +
"[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
)
}
@@ -235,13 +245,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
/**
* Test whether checkpointing of the parent of the generated RDD also
* truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
- * RDDs splits. So even if the parent RDD is checkpointed and its splits changed,
- * this RDD will remember the splits and therefore potentially the whole lineage.
+ * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
+ * this RDD will remember the partitions and therefore potentially the whole lineage.
*/
def testParentCheckpointing[U: ClassManifest](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean,
- testRDDSplitSize: Boolean
+ testRDDPartitionSize: Boolean
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD()
@@ -250,9 +260,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
- // Get the splits and dependencies of the parent in case they're lazily computed
+ // Get the partitions and dependencies of the parent in case they're lazily computed
parentRDD.dependencies
- parentRDD.splits
+ parentRDD.partitions
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
@@ -275,16 +285,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
)
}
- // Test whether serialized size of the splits has reduced because of its parent being
- // checkpointed. If the splits do not have any non-transient reference to another RDD
- // or another RDD's splits, it does not refer to a lineage and therefore may not reduce
- // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
- // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's
- // splits must have changed after checkpointing.
- if (testRDDSplitSize) {
+ // Test whether serialized size of the partitions has reduced because of its parent being
+ // checkpointed. If the partitions do not have any non-transient reference to another RDD
+ // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
+ // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
+ // partitions must have changed after checkpointing.
+ if (testRDDPartitionSize) {
assert(
splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
"[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
)
}
@@ -321,12 +331,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
/**
- * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
+ * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
(Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
- Utils.serialize(rdd.splits).length)
+ Utils.serialize(rdd.partitions).length)
}
/**
@@ -347,7 +357,7 @@ object CheckpointSuite {
def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
//println("First = " + first + ", second = " + second)
new CoGroupedRDD[K](
- Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]),
part
).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 0e2585daa4..caa4ba3a37 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -217,6 +217,27 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
assert(grouped.collect.size === 1)
}
}
+
+ test("recover from node failures with replication") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ // Using more than two nodes so we don't have a symmetric communication pattern and might
+ // cache a partially correct list of peers.
+ sc = new SparkContext("local-cluster[3,1,512]", "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false, false, false), 4)
+ data.persist(StorageLevel.MEMORY_ONLY_2)
+
+ assert(data.count === 4)
+ assert(data.map(markNodeIfIdentity).collect.size === 4)
+ assert(data.map(failOnMarkedIdentity).collect.size === 4)
+
+ // Create a new replicated RDD to make sure that cached peer information doesn't cause
+ // problems.
+ val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2)
+ assert(data2.count === 2)
+ }
+ }
}
object DistributedSuite {
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 934e4c2f67..9ffe7c5f99 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -696,4 +696,28 @@ public class JavaAPISuite implements Serializable {
JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
}
+
+ @Test
+ public void mapOnPairRDD() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4));
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+ JavaPairRDD<Integer, Integer> rdd3 = rdd2.map(
+ new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception {
+ return new Tuple2<Integer, Integer>(in._2(), in._1());
+ }
+ });
+ Assert.assertEquals(Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(0, 2),
+ new Tuple2<Integer, Integer>(1, 3),
+ new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
+
+ }
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index af1107cd19..60db759c25 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -84,10 +84,10 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner)
assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner)
- assert(grouped2.join(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped2.partitioner)
- assert(grouped2.cogroup(grouped4).partitioner === grouped2.partitioner)
+ assert(grouped2.join(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner)
+ assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner)
assert(grouped2.join(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index fe7deb10d6..9739ba869b 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -33,6 +33,11 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
+ val partitionSumsWithIndex = nums.mapPartitionsWithIndex {
+ case(split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+ }
+ assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7)))
+
intercept[UnsupportedOperationException] {
nums.filter(_ > 5).reduce(_ + _)
}
@@ -97,12 +102,12 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("caching with failures") {
sc = new SparkContext("local", "test")
- val onlySplit = new Split { override def index: Int = 0 }
+ val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
- override def getSplits: Array[Split] = Array(onlySplit)
+ override def getPartitions: Array[Partition] = Array(onlySplit)
override val getDependencies = List[Dependency[_]]()
- override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
if (shouldFail) {
throw new Exception("injected failure")
} else {
@@ -122,7 +127,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
- val coalesced1 = new CoalescedRDD(data, 2)
+ val coalesced1 = data.coalesce(2)
assert(coalesced1.collect().toList === (1 to 10).toList)
assert(coalesced1.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
@@ -133,19 +138,19 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList ===
List(5, 6, 7, 8, 9))
- val coalesced2 = new CoalescedRDD(data, 3)
+ val coalesced2 = data.coalesce(3)
assert(coalesced2.collect().toList === (1 to 10).toList)
assert(coalesced2.glom().collect().map(_.toList).toList ===
List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10)))
- val coalesced3 = new CoalescedRDD(data, 10)
+ val coalesced3 = data.coalesce(10)
assert(coalesced3.collect().toList === (1 to 10).toList)
assert(coalesced3.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
// If we try to coalesce into more partitions than the original RDD, it should just
// keep the original number of partitions.
- val coalesced4 = new CoalescedRDD(data, 20)
+ val coalesced4 = data.coalesce(20)
assert(coalesced4.collect().toList === (1 to 10).toList)
assert(coalesced4.glom().collect().map(_.toList).toList ===
(1 to 10).map(x => List(x)).toList)
@@ -168,7 +173,7 @@ class RDDSuite extends FunSuite with LocalSparkContext {
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
- assert(prunedRdd.splits.size === 1)
+ assert(prunedRdd.partitions.size === 1)
val prunedData = prunedRdd.collect()
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 3493b9511f..8411291b2c 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -1,6 +1,7 @@
package spark
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
@@ -98,6 +99,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
+
+ test("reduceByKey with partitioner") {
+ sc = new SparkContext("local", "test")
+ val p = new Partitioner() {
+ def numPartitions = 2
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
+ val sums = pairs.reduceByKey(_+_)
+ assert(sums.collect().toSet === Set((1, 4), (0, 1)))
+ assert(sums.partitioner === Some(p))
+ // count the dependencies to make sure there is only 1 ShuffledRDD
+ val deps = new HashSet[RDD[_]]()
+ def visit(r: RDD[_]) {
+ for (dep <- r.dependencies) {
+ deps += dep.rdd
+ visit(dep.rdd)
+ }
+ }
+ visit(sums)
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ }
test("join") {
sc = new SparkContext("local", "test")
@@ -199,7 +222,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
- assert(file.splits.size == 0)
+ assert(file.partitions.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
@@ -211,6 +234,51 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
+
+ test("default partitioner uses partition size") {
+ sc = new SparkContext("local", "test")
+ // specify 2000 partitions
+ val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
+ // do a map, which loses the partitioner
+ val b = a.map(a => (a, (a * 2).toString))
+ // then a group by, and see we didn't revert to 2 partitions
+ val c = b.groupByKey()
+ assert(c.partitions.size === 2000)
+ }
+
+ test("default partitioner uses largest partitioner") {
+ sc = new SparkContext("local", "test")
+ val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
+ val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
+ val c = a.join(b)
+ assert(c.partitions.size === 2000)
+ }
+
+ test("subtract") {
+ sc = new SparkContext("local", "test")
+ val a = sc.parallelize(Array(1, 2, 3), 2)
+ val b = sc.parallelize(Array(2, 3, 4), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set(1))
+ assert(c.partitions.size === a.partitions.size)
+ }
+
+ test("subtract with narrow dependency") {
+ sc = new SparkContext("local", "test")
+ // use a deterministic partitioner
+ val p = new Partitioner() {
+ def numPartitions = 5
+ def getPartition(key: Any) = key.asInstanceOf[Int]
+ }
+ // partitionBy so we have a narrow dependency
+ val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+ println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
+ // more partitions/no partitioner so a shuffle dependency
+ val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+ val c = a.subtract(b)
+ assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+ assert(c.partitioner.get === p)
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index edb8c839fc..495f957e53 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -19,7 +19,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey()
- assert(sorted.splits.size === 2)
+ assert(sorted.partitions.size === 2)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -29,17 +29,17 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 1)
- assert(sorted.splits.size === 1)
+ assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
- test("large array with many splits") {
+ test("large array with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
val sorted = pairs.sortByKey(true, 20)
- assert(sorted.splits.size === 20)
+ assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
@@ -59,7 +59,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
- test("sort descending with many splits") {
+ test("sort descending with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
index 450c69bd58..d27a2538e4 100644
--- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala
@@ -1,4 +1,4 @@
-package spark
+package spark.rdd
import scala.collection.immutable.NumericRange
@@ -11,7 +11,7 @@ import org.scalacheck.Prop._
class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one element per slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1")
assert(slices(1).mkString(",") === "2")
@@ -20,14 +20,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("one slice") {
val data = Array(1, 2, 3)
- val slices = ParallelCollection.slice(data, 1)
+ val slices = ParallelCollectionRDD.slice(data, 1)
assert(slices.size === 1)
assert(slices(0).mkString(",") === "1,2,3")
}
test("equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -36,7 +36,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("non-equal slices") {
val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === "1,2,3")
assert(slices(1).mkString(",") === "4,5,6")
@@ -45,7 +45,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting exclusive range") {
val data = 0 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 65).mkString(","))
@@ -54,7 +54,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("splitting inclusive range") {
val data = 0 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices(0).mkString(",") === (0 to 32).mkString(","))
assert(slices(1).mkString(",") === (33 to 66).mkString(","))
@@ -63,24 +63,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("empty data") {
val data = new Array[Int](0)
- val slices = ParallelCollection.slice(data, 5)
+ val slices = ParallelCollectionRDD.slice(data, 5)
assert(slices.size === 5)
for (slice <- slices) assert(slice.size === 0)
}
test("zero slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) }
}
test("negative number of slices") {
val data = Array(1, 2, 3)
- intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) }
+ intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) }
}
test("exclusive ranges sliced into ranges") {
val data = 1 until 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -88,7 +88,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges sliced into ranges") {
val data = 1 to 100
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[Range]))
@@ -97,7 +97,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("large ranges don't overflow") {
val N = 100 * 1000 * 1000
val data = 0 until N
- val slices = ParallelCollection.slice(data, 40)
+ val slices = ParallelCollectionRDD.slice(data, 40)
assert(slices.size === 40)
for (i <- 0 until 40) {
assert(slices(i).isInstanceOf[Range])
@@ -117,7 +117,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
(tuple: (List[Int], Int)) =>
val d = tuple._1
val n = tuple._2
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1))
@@ -134,7 +134,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a until b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -152,7 +152,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
} yield (a to b by step, n)
val prop = forAll(gen) {
case (d: Range, n: Int) =>
- val slices = ParallelCollection.slice(d, n)
+ val slices = ParallelCollectionRDD.slice(d, n)
("n slices" |: slices.size == n) &&
("all ranges" |: slices.forall(_.isInstanceOf[Range])) &&
("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) &&
@@ -163,7 +163,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of longs") {
val data = 1L until 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -171,7 +171,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of longs") {
val data = 1L to 100L
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -179,7 +179,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("exclusive ranges of doubles") {
val data = 1.0 until 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 99)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
@@ -187,7 +187,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
test("inclusive ranges of doubles") {
val data = 1.0 to 100.0 by 1.0
- val slices = ParallelCollection.slice(data, 3)
+ val slices = ParallelCollectionRDD.slice(data, 3)
assert(slices.size === 3)
assert(slices.map(_.size).reduceLeft(_+_) === 100)
assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index edc5a7dfba..07cccc7ce0 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -14,7 +14,7 @@ import spark.MapOutputTracker
import spark.RDD
import spark.SparkContext
import spark.SparkException
-import spark.Split
+import spark.Partition
import spark.TaskContext
import spark.TaskEndReason
@@ -111,18 +111,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
* so we can test that DAGScheduler does not try to execute RDDs locally.
*/
private def makeRdd(
- numSplits: Int,
+ numPartitions: Int,
dependencies: List[Dependency[_]],
locations: Seq[Seq[String]] = Nil
): MyRDD = {
- val maxSplit = numSplits - 1
+ val maxPartition = numPartitions - 1
return new MyRDD(sc, dependencies) {
- override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] =
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
- override def getSplits() = (0 to maxSplit).map(i => new Split {
+ override def getPartitions = (0 to maxPartition).map(i => new Partition {
override def index = i
}).toArray
- override def getPreferredLocations(split: Split): Seq[String] =
+ override def getPreferredLocations(split: Partition): Seq[String] =
if (locations.isDefinedAt(split.index))
locations(split.index)
else
@@ -196,9 +196,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
test("local job") {
val rdd = new MyRDD(sc, Nil) {
- override def compute(split: Split, context: TaskContext) = Array(42 -> 0).iterator
- override def getSplits() = Array(new Split { override def index = 0 })
- override def getPreferredLocations(split: Split) = Nil
+ override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
+ Array(42 -> 0).iterator
+ override def getPartitions = Array( new Partition { override def index = 0 } )
+ override def getPreferredLocations(split: Partition) = Nil
override def toString = "DAGSchedulerSuite Local RDD"
}
runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
@@ -397,4 +398,4 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345)
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
index a5db7103f5..647bcaf860 100644
--- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -5,7 +5,7 @@ import org.scalatest.BeforeAndAfter
import spark.TaskContext
import spark.RDD
import spark.SparkContext
-import spark.Split
+import spark.Partition
import spark.LocalSparkContext
class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
@@ -14,8 +14,8 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
var completed = false
sc = new SparkContext("local", "test")
val rdd = new RDD[String](sc, List()) {
- override def getSplits = Array[Split](StubSplit(0))
- override def compute(split: Split, context: TaskContext) = {
+ override def getPartitions = Array[Partition](StubPartition(0))
+ override def compute(split: Partition, context: TaskContext) = {
context.addOnCompleteCallback(() => completed = true)
sys.error("failed")
}
@@ -28,5 +28,5 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
assert(completed === true)
}
- case class StubSplit(val index: Int) extends Split
-} \ No newline at end of file
+ case class StubPartition(val index: Int) extends Partition
+}
diff --git a/docs/_config.yml b/docs/_config.yml
index 2bd2eecc86..09617e4a1e 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -7,3 +7,4 @@ SPARK_VERSION: 0.7.0-SNAPSHOT
SPARK_VERSION_SHORT: 0.7.0
SCALA_VERSION: 2.9.2
MESOS_VERSION: 0.9.0-incubating
+SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
diff --git a/docs/configuration.md b/docs/configuration.md
index a7054b4321..04eb6daaa5 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -183,7 +183,7 @@ Apart from these, the following properties are also available, and may be useful
</tr>
<tr>
<td>spark.broadcast.factory</td>
- <td>spark.broadcast. HttpBroadcastFactory</td>
+ <td>spark.broadcast.HttpBroadcastFactory</td>
<td>
Which broadcast implementation to use.
</td>
@@ -198,6 +198,14 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.worker.timeout</td>
+ <td>60</td>
+ <td>
+ Number of seconds after which the standalone deploy master considers a worker lost if it
+ receives no heartbeats.
+ </td>
+</tr>
+<tr>
<td>spark.akka.frameSize</td>
<td>10</td>
<td>
@@ -218,7 +226,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.akka.timeout</td>
<td>20</td>
<td>
- Communication timeout between Spark nodes.
+ Communication timeout between Spark nodes, in seconds.
</td>
</tr>
<tr>
@@ -236,10 +244,10 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
- <td>spark.cleaner.delay</td>
+ <td>spark.cleaner.ttl</td>
<td>(disable)</td>
<td>
- Duration (minutes) of how long Spark will remember any metadata (stages generated, tasks generated, etc.).
+ Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks generated, etc.).
Periodic cleanups will ensure that metadata older than this duration will be forgetten. This is
useful for running Spark for many hours / days (for example, running 24/7 in case of Spark Streaming
applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md
index c6e01c62d8..50feeb2d6c 100644
--- a/docs/contributing-to-spark.md
+++ b/docs/contributing-to-spark.md
@@ -15,7 +15,7 @@ The Spark team welcomes contributions in the form of GitHub pull requests. Here
But first, make sure that you have [configured a spark-env.sh](configuration.html) with at least
`SCALA_HOME`, as some of the tests try to spawn subprocesses using this.
- Add new unit tests for your code. We use [ScalaTest](http://www.scalatest.org/) for testing. Just add a new Suite in `core/src/test`, or methods to an existing Suite.
-- If you'd like to report a bug but don't have time to fix it, you can still post it to our [issues page](https://github.com/mesos/spark/issues), or email the [mailing list](http://www.spark-project.org/mailing-lists.html).
+- If you'd like to report a bug but don't have time to fix it, you can still post it to our [issue tracker]({{site.SPARK_ISSUE_TRACKER_URL}}), or email the [mailing list](http://www.spark-project.org/mailing-lists.html).
# Licensing of Contributions
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 4e84d23edf..2012241a6a 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -87,7 +87,7 @@ By default, the `pyspark` shell creates SparkContext that runs jobs locally.
To connect to a non-local cluster, set the `MASTER` environment variable.
For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html):
-{% highlight shell %}
+{% highlight bash %}
$ MASTER=spark://IP:PORT ./pyspark
{% endhighlight %}
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 301b330a79..b98718a553 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -203,7 +203,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in
<tr><th>Action</th><th>Meaning</th></tr>
<tr>
<td> <b>reduce</b>(<i>func</i>) </td>
- <td> Aggregate the elements of the dataset using a function <i>func</i> (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td>
+ <td> Aggregate the elements of the dataset using a function <i>func</i> (which takes two arguments and returns one). The function should be commutative and associative so that it can be computed correctly in parallel. </td>
</tr>
<tr>
<td> <b>collect</b>() </td>
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index bf296221b8..3986c0c79d 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -115,6 +115,14 @@ You can optionally configure the cluster further by setting environment variable
<td><code>SPARK_WORKER_WEBUI_PORT</code></td>
<td>Port for the worker web UI (default: 8081)</td>
</tr>
+ <tr>
+ <td><code>SPARK_DAEMON_MEMORY</code></td>
+ <td>Memory to allocate to the Spark master and worker daemons themselves (default: 512m)</td>
+ </tr>
+ <tr>
+ <td><code>SPARK_DAEMON_JAVA_OPTS</code></td>
+ <td>JVM options for the Spark master and worker daemons themselves (default: none)</td>
+ </tr>
</table>
diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md
new file mode 100644
index 0000000000..5476c00d02
--- /dev/null
+++ b/docs/streaming-custom-receivers.md
@@ -0,0 +1,101 @@
+---
+layout: global
+title: Tutorial - Spark Streaming, Plugging in a custom receiver.
+---
+
+A "Spark Streaming" receiver can be a simple network stream, streams of messages from a message queue, files etc. A receiver can also assume roles more than just receiving data like filtering, preprocessing, to name a few of the possibilities. The api to plug-in any user defined custom receiver is thus provided to encourage development of receivers which may be well suited to ones specific need.
+
+This guide shows the programming model and features by walking through a simple sample receiver and corresponding Spark Streaming application.
+
+
+## A quick and naive walk-through
+
+### Write a simple receiver
+
+This starts with implementing [Actor](#References)
+
+Following is a simple socket text-stream receiver, which is appearently overly simplified using Akka's socket.io api.
+
+{% highlight scala %}
+
+ class SocketTextStreamReceiver (host:String,
+ port:Int,
+ bytesToString: ByteString => String) extends Actor with Receiver {
+
+ override def preStart = IOManager(context.system).connect(host, port)
+
+ def receive = {
+ case IO.Read(socket, bytes) => pushBlock(bytesToString(bytes))
+ }
+
+ }
+
+
+{% endhighlight %}
+
+All we did here is mixed in trait Receiver and called pushBlock api method to push our blocks of data. Please refer to scala-docs of Receiver for more details.
+
+### A sample spark application
+
+* First create a Spark streaming context with master url and batchduration.
+
+{% highlight scala %}
+
+ val ssc = new StreamingContext(master, "WordCountCustomStreamSource",
+ Seconds(batchDuration))
+
+{% endhighlight %}
+
+* Plug-in the actor configuration into the spark streaming context and create a DStream.
+
+{% highlight scala %}
+
+ val lines = ssc.actorStream[String](Props(new SocketTextStreamReceiver(
+ "localhost",8445, z => z.utf8String)),"SocketReceiver")
+
+{% endhighlight %}
+
+* Process it.
+
+{% highlight scala %}
+
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+
+ wordCounts.print()
+ ssc.start()
+
+
+{% endhighlight %}
+
+* After processing it, stream can be tested using the netcat utility.
+
+ $ nc -l localhost 8445
+ hello world
+ hello hello
+
+
+## Multiple homogeneous/heterogeneous receivers.
+
+A DStream union operation is provided for taking union on multiple input streams.
+
+{% highlight scala %}
+
+ val lines = ssc.actorStream[String](Props(new SocketTextStreamReceiver(
+ "localhost",8445, z => z.utf8String)),"SocketReceiver")
+
+ // Another socket stream receiver
+ val lines2 = ssc.actorStream[String](Props(new SocketTextStreamReceiver(
+ "localhost",8446, z => z.utf8String)),"SocketReceiver")
+
+ val union = lines.union(lines2)
+
+{% endhighlight %}
+
+Above stream can be easily process as described earlier.
+
+_A more comprehensive example is provided in the spark streaming examples_
+
+## References
+
+1.[Akka Actor documentation](http://doc.akka.io/docs/akka/2.0.5/scala/actors.html)
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index b6da7af654..0e618a06c7 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -34,34 +34,34 @@ The StreamingContext is used to creating InputDStreams from input sources:
{% highlight scala %}
// Assuming ssc is the StreamingContext
-ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port
-ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory
+ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory
+ssc.socketStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port
{% endhighlight %}
-A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next.
+We also provide a input streams for Kafka, Flume, Akka actor, etc. For a complete list of input streams, take a look at the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext).
# DStream Operations
-Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source.
+Data received from the input streams can be processed using _DStream operations_. There are two kinds of operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams with transformed data. After applying a sequence of transformations to the input streams, you'll need to call the output operations, which writies data out to an external source.
## Transformations
DStreams support many of the transformations available on normal Spark RDD's:
<table class="table">
-<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr>
+<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr>
<tr>
<td> <b>map</b>(<i>func</i>) </td>
- <td> Returns a new DStream formed by passing each element of the source through a function <i>func</i>. </td>
+ <td> Returns a new DStream formed by passing each element of the source DStream through a function <i>func</i>. </td>
</tr>
<tr>
<td> <b>filter</b>(<i>func</i>) </td>
- <td> Returns a new stream formed by selecting those elements of the source on which <i>func</i> returns true. </td>
+ <td> Returns a new DStream formed by selecting those elements of the source DStream on which <i>func</i> returns true. </td>
</tr>
<tr>
<td> <b>flatMap</b>(<i>func</i>) </td>
- <td> Similar to map, but each input item can be mapped to 0 or more output items (so <i>func</i> should return a Seq rather than a single item). </td>
+ <td> Similar to map, but each input item can be mapped to 0 or more output items (so <i>func</i> should return a <code>Seq</code> rather than a single item). </td>
</tr>
<tr>
<td> <b>mapPartitions</b>(<i>func</i>) </td>
@@ -70,73 +70,92 @@ DStreams support many of the transformations available on normal Spark RDD's:
</tr>
<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
- <td> Return a new stream that contains the union of the elements in the source stream and the argument. </td>
+ <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
+</tr>
+<tr>
+ <td> <b>count</b>() </td>
+ <td> Returns a new DStream of single-element RDDs by counting the number of elements in each RDD of the source DStream. </td>
+</tr>
+<tr>
+ <td> <b>reduce</b>(<i>func</i>) </td>
+ <td> Returns a new DStream of single-element RDDs by aggregating the elements in each RDD of the source DStream using a function <i>func</i> (which takes two arguments and returns one). The function should be associative so that it can be computed in parallel. </td>
+</tr>
+<tr>
+ <td> <b>countByValue</b>() </td>
+ <td> When called on a DStream of elements of type K, returns a new DStream of (K, Long) pairs where the value of each key is its frequency in each RDD of the source DStream. </td>
</tr>
<tr>
<td> <b>groupByKey</b>([<i>numTasks</i>]) </td>
- <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs. <br />
-<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.
+ <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together all the values of each key in the RDDs of the source DStream. <br />
+ <b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.
</td>
</tr>
<tr>
<td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
- <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+ <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
</tr>
<tr>
<td> <b>join</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td>
- <td> When called on streams of type (K, V) and (K, W), returns a stream of (K, (V, W)) pairs with all pairs of elements for each key. </td>
+ <td> When called on two DStreams of (K, V) and (K, W) pairs, returns a new DStream of (K, (V, W)) pairs with all pairs of elements for each key. </td>
</tr>
<tr>
<td> <b>cogroup</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td>
- <td> When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.</td>
-</tr>
-<tr>
- <td> <b>reduce</b>(<i>func</i>) </td>
- <td> Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td>
+ <td> When called on DStream of (K, V) and (K, W) pairs, returns a new DStream of (K, Seq[V], Seq[W]) tuples.</td>
</tr>
<tr>
<td> <b>transform</b>(<i>func</i>) </td>
<td> Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream. </td>
</tr>
+<tr>
+ <td> <b>updateStateByKey</b>(<i>func</i>) </td>
+ <td> Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of each key. This can be used to track session state by using the session-id as the key and updating the session state as new data is received.</td>
+</tr>
+
</table>
-Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a <i>windowDuration</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated.
+Spark Streaming features windowed computations, which allow you to apply transformations over a sliding window of data. All window functions take a <i>windowDuration</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated.
<table class="table">
-<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr>
+<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr>
<tr>
- <td> <b>window</b>(<i>windowDuration</i>, </i>slideTime</i>) </td>
- <td> Return a new stream which is computed based on windowed batches of the source stream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
+ <td> <b>window</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> Return a new DStream which is computed based on windowed batches of the source DStream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
</td>
</tr>
<tr>
- <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideTime</i>) </td>
+ <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
<td> Return a sliding count of elements in the stream. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
<tr>
- <td> <b>reduceByWindow</b>(<i>func</i>, <i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> <b>reduceByWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>) </td>
<td> Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using <i>func</i>. The function should be associative so that it can be computed correctly in parallel. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
<tr>
- <td> <b>groupByKeyAndWindow</b>(windowDuration, slideDuration, [<i>numTasks</i>])
+ <td> <b>groupByKeyAndWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>])
</td>
- <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window. <br />
-<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
-</td>
+ <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together values of each key over batches in a sliding window. <br />
+<b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.</td>
</tr>
<tr>
- <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, [<i>numTasks</i>]) </td>
- <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
+ <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td>
+ <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function <i>func</i> over batches in a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
<i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
<tr>
- <td> <b>countByKeyAndWindow</b>([<i>numTasks</i>]) </td>
- <td> When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in <code>countByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
+ <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, <i>invFunc</i>, <i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td>
+ <td> A more efficient version of the above <code>reduceByKeyAndWindow()</code> where the reduce value of each window is calculated
+ incrementally using the reduce values of the previous window. This is done by reducing the new data that enter the sliding window, and "inverse reducing" the old data that leave the window. An example would be that of "adding" and "subtracting" counts of keys as the window slides. However, it is applicable to only "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as parameter <i>invFunc</i>. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
<i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
-</td>
+</td>
+</tr>
+<tr>
+ <td> <b>countByValueAndWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td>
+ <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Long) pairs where the value of each key is its frequency within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
+ <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+</td>
</tr>
</table>
@@ -147,7 +166,7 @@ A complete list of DStream operations is available in the API documentation of [
When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined:
<table class="table">
-<tr><th style="width:25%">Operator</th><th>Meaning</th></tr>
+<tr><th style="width:30%">Operator</th><th>Meaning</th></tr>
<tr>
<td> <b>foreach</b>(<i>func</i>) </td>
<td> The fundamental output operator. Applies a function, <i>func</i>, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system. </td>
@@ -176,11 +195,6 @@ When an output operator is called, it triggers the computation of a stream. Curr
</table>
-## DStream Persistence
-Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`.
-
-Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence).
-
# Starting the Streaming computation
All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using
{% highlight scala %}
@@ -192,8 +206,8 @@ Conversely, the computation can be stopped by using
ssc.stop()
{% endhighlight %}
-# Example - NetworkWordCount.scala
-A good example to start off is the spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in <Spark repo>/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala.
+# Example
+A simple example to start off is the [NetworkWordCount](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/NetworkWordCount.scala` .
{% highlight scala %}
import spark.streaming.{Seconds, StreamingContext}
@@ -202,7 +216,7 @@ import spark.streaming.StreamingContext._
// Create the context and set up a network input stream to receive from a host:port
val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1))
-val lines = ssc.networkTextStream(args(1), args(2).toInt)
+val lines = ssc.socketTextStream(args(1), args(2).toInt)
// Split the lines into words, count them, and print some of the counts on the master
val words = lines.flatMap(_.split(" "))
@@ -213,6 +227,8 @@ wordCounts.print()
ssc.start()
{% endhighlight %}
+The `socketTextStream` returns a DStream of lines received from a TCP socket-based source. The `lines` DStream is _transformed_ into a DStream using the `flatMap` operation, where each line is split into words. This `words` DStream is then mapped to a DStream of `(word, 1)` pairs, which is finally reduced to get the word counts. `wordCounts.print()` will print 10 of the counts generated every second.
+
To run this example on your local machine, you need to first run a Netcat server by using
{% highlight bash %}
@@ -260,6 +276,31 @@ Time: 1357008430000 ms
</td>
</table>
+You can find more examples in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files.
+
+# DStream Persistence
+Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`.
+
+For input streams that receive data from the network (that is, subclasses of NetworkInputDStream like FlumeInputDStream and KafkaInputDStream), the default persistence level is set to replicate the data to two nodes for fault-tolerance.
+
+Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence).
+
+# RDD Checkpointing within DStreams
+DStreams created by stateful operations like `updateStateByKey` require the RDDs in the DStream to be periodically saved to HDFS files for checkpointing. This is because, unless checkpointed, the lineage of operations of the state RDDs can increase indefinitely (since each RDD in the DStream depends on the previous RDD). This leads to two problems - (i) the size of Spark tasks increase proportionally with the RDD lineage leading higher task launch times, (ii) no limit on the amount of recomputation required on failure. Checkpointing RDDs at some interval by writing them to HDFS allows the lineage to be truncated. Note that checkpointing also incurs the cost of saving to HDFS which may cause the corresponding batch to take longer to process. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too slowly causes the lineage and task sizes to grow which may have detrimental effects. Typically, a checkpoint interval of 5 - 10 times of sliding interval of a DStream is good setting to try.
+
+To enable checkpointing, the developer has to provide the HDFS path to which RDD will be saved. This is done by using
+
+{% highlight scala %}
+ssc.checkpoint(hdfsPath) // assuming ssc is the StreamingContext
+{% endhighlight %}
+
+The interval of checkpointing of a DStream can be set by using
+
+{% highlight scala %}
+dstream.checkpoint(checkpointInterval) // checkpointInterval must be a multiple of slide duration of dstream
+{% endhighlight %}
+
+For DStreams that must be checkpointed (that is, DStreams created by `updateStateByKey` and `reduceByKeyAndWindow` with inverse function), the checkpoint interval of the DStream is by default set to a multiple of the DStream's sliding interval such that its at least 10 seconds.
# Performance Tuning
@@ -273,17 +314,21 @@ Getting the best performance of a Spark Streaming application on a cluster requi
There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones.
### Level of Parallelism
-Cluster resources maybe underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default.
+Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default.
### Data Serialization
The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it.
-* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC.
-* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck.
+
+* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC.
+
+* **Serialization of input data**: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck.
### Task Launching Overheads
If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes:
-* Task Serialization: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves.
-* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details.
+
+* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves.
+
+* **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details.
These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable.
## Setting the Right Batch Size
@@ -292,22 +337,182 @@ For a Spark Streaming application running on a cluster to be stable, the process
A good approach to figure out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained to be less than the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that momentary increase in the delay due to temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size).
## 24/7 Operation
-By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created.
+By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.ttl` to the number of seconds you want any metadata to persist. For example, setting `spark.cleaner.ttl` to 600 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created.
This value is closely tied with any window operation that is being used. Any window operation would require the input data to be persisted in memory for at least the duration of the window. Hence it is necessary to set the delay to at least the value of the largest window operation used in the Spark Streaming application. If this delay is set too low, the application will throw an exception saying so.
## Memory Tuning
Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, we highlight a few customizations that are strongly recommended to minimize GC related pauses in Spark Streaming applications and achieving more consistent batch processing times.
-* <b>Default persistence level of DStreams</b>: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses.
+* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses.
+
+* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
+
+# Fault-tolerance Properties
+In this section, we are going to discuss the behavior of Spark Streaming application in the event of a node failure. To understand this, let us remember the basic fault-tolerance properties of Spark's RDDs.
+
+ 1. An RDD is an immutable, and deterministically re-computable, distributed dataset. Each RDD remembers the lineage of deterministic operations that were used on a fault-tolerant input dataset to create it.
+ 1. If any partition of an RDD is lost due to a worker node failure, then that partition can be re-computed from the original fault-tolerant dataset using the lineage of operations.
+
+Since all data transformations in Spark Streaming are based on RDD operations, as long as the input dataset is present, all intermediate data can recomputed. Keeping these properties in mind, we are going to discuss the failure semantics in more detail.
+
+## Failure of a Worker Node
+
+There are two failure behaviors based on which input sources are used.
+
+1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure.
+1. _Using any input source that receives data through a network_ - For network-based data sources like Kafka and Flume, the received input data is replicated in memory between nodes of the cluster (default replication factor is 2). So if a worker node fails, then the system can recompute the lost from the the left over copy of the input data. However, if the worker node where a network receiver was running fails, then a tiny bit of data may be lost, that is, the data received by the system but not yet replicated to other node(s). The receiver will be started on a different node and it will continue to receive data.
+
+Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreach`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations.
+
+## Failure of the Driver Node
+A system that is required to operate 24/7 needs to be able tolerate the failure of the driver node as well. Spark Streaming does this by saving the state of the DStream computation periodically to a HDFS file, that can be used to restart the streaming computation in the event of a failure of the driver node. This checkpointing is enabled by setting a HDFS directory for checkpointing using `ssc.checkpoint(<checkpoint directory>)` as described [earlier](#rdd-checkpointing-within-dstreams). To elaborate, the following state is periodically saved to a file.
+
+1. The DStream operator graph (input streams, output streams, etc.)
+1. The configuration of each DStream (checkpoint interval, etc.)
+1. The RDD checkpoint files of each DStream
+
+All this is periodically saved in the file `<checkpoint directory>/graph`. To recover, a new Streaming Context can be created with this directory by using
+
+{% highlight scala %}
+val ssc = new StreamingContext(checkpointDirectory)
+{% endhighlight %}
+
+On calling `ssc.start()` on this new context, the following steps are taken by the system
+
+1. Schedule the transformations and output operations for all the time steps between the time when the driver failed and when it was restarted. This is also done for those time steps that were scheduled but not processed due to the failure. This will make the system recompute all the intermediate data from the checkpointed RDD files, etc.
+1. Restart the network receivers, if any, and continue receiving new data.
-* <b>Concurrent garbage collector</b>: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
+In the current _alpha_ release, there are two different failure behaviors based on which input sources are used.
-# Master Fault-tolerance (Alpha)
-TODO
+1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure.
+1. _Using any input source that receives data through a network_ - As aforesaid, the received input data is replicated in memory to multiple nodes. Since, all the data in the Spark worker's memory is lost when the Spark driver fails, the past input data will not be accessible and driver recovers. Hence, if stateful and window-based operations are used (like `updateStateByKey`, `window`, `countByValueAndWindow`, etc.), then the intermediate state will not be recovered completely.
+
+In future releases, this behaviour will be fixed for all input sources, that is, all data will be recovered irrespective of which input sources are used. Note that for non-stateful transformations like `map`, `count`, and `reduceByKey`, with _all_ input streams, the system, upon restarting, will continue to receive and process new data.
+
+To better understand the behavior of the system under driver failure with a HDFS source, lets consider what will happen with a file input stream Specifically, in the case of the file input stream, it will correctly identify new files that were created while the driver was down and process them in the same way as it would have if the driver had not failed. To explain further in the case of file input stream, we shall use an example. Lets say, files are being generated every second, and a Spark Streaming program reads every new file and output the number of lines in the file. This is what the sequence of outputs would be with and without a driver failure.
+
+<table class="table">
+ <!-- Results table headers -->
+ <tr>
+ <th> Time </th>
+ <th> Number of lines in input file </th>
+ <th> Output without driver failure </th>
+ <th> Output with driver failure </th>
+ </tr>
+ <tr>
+ <td>1</td>
+ <td>10</td>
+ <td>10</td>
+ <td>10</td>
+ </tr>
+ <tr>
+ <td>2</td>
+ <td>20</td>
+ <td>20</td>
+ <td>20</td>
+ </tr>
+ <tr>
+ <td>3</td>
+ <td>30</td>
+ <td>30</td>
+ <td>30</td>
+ </tr>
+ <tr>
+ <td>4</td>
+ <td>40</td>
+ <td>40</td>
+ <td>[DRIVER FAILS]<br />no output</td>
+ </tr>
+ <tr>
+ <td>5</td>
+ <td>50</td>
+ <td>50</td>
+ <td>no output</td>
+ </tr>
+ <tr>
+ <td>6</td>
+ <td>60</td>
+ <td>60</td>
+ <td>no output</td>
+ </tr>
+ <tr>
+ <td>7</td>
+ <td>70</td>
+ <td>70</td>
+ <td>[DRIVER RECOVERS]<br />40, 50, 60, 70</td>
+ </tr>
+ <tr>
+ <td>8</td>
+ <td>80</td>
+ <td>80</td>
+ <td>80</td>
+ </tr>
+ <tr>
+ <td>9</td>
+ <td>90</td>
+ <td>90</td>
+ <td>90</td>
+ </tr>
+ <tr>
+ <td>10</td>
+ <td>100</td>
+ <td>100</td>
+ <td>100</td>
+ </tr>
+</table>
+
+If the driver had crashed in the middle of the processing of time 3, then it will process time 3 and output 30 after recovery.
+
+# Java API
+
+Similar to [Spark's Java API](java-programming-guide.html), we also provide a Java API for Spark Streaming which allows all its features to be accessible from a Java program. This is defined in [spark.streaming.api.java] (api/streaming/index.html#spark.streaming.api.java.package) package and includes [JavaStreamingContext](api/streaming/index.html#spark.streaming.api.java.JavaStreamingContext) and [JavaDStream](api/streaming/index.html#spark.streaming.api.java.JavaDStream) classes that provide the same methods as their Scala counterparts, but take Java functions (that is, Function, and Function2) and return Java data and collection types. Some of the key points to note are:
+
+1. Functions for transformations must be implemented as subclasses of [Function](api/core/index.html#spark.api.java.function.Function) and [Function2](api/core/index.html#spark.api.java.function.Function2)
+1. Unlike the Scala API, the Java API handles DStreams for key-value pairs using a separate [JavaPairDStream](api/streaming/index.html#spark.streaming.api.java.JavaPairDStream) class(similar to [JavaRDD and JavaPairRDD](java-programming-guide.html#rdd-classes). DStream functions like `map` and `filter` are implemented separately by JavaDStreams and JavaPairDStream to return DStreams of appropriate types.
+
+Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `<spark repo>/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java`
+
+The streaming context and the socket stream from input source is started by using a `JavaStreamingContext`, that has the same parameters and provides the same input streams as its Scala counterpart.
+
+{% highlight java %}
+JavaStreamingContext ssc = new JavaStreamingContext(mesosUrl, "NetworkWordCount", Seconds(1));
+JavaDStream<String> lines = ssc.socketTextStream(ip, port);
+{% endhighlight %}
+
+
+Then the `lines` are split into words by using the `flatMap` function and [FlatMapFunction](api/core/index.html#spark.api.java.function.FlatMapFunction).
+
+{% highlight java %}
+JavaDStream<String> words = lines.flatMap(
+ new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split(" "));
+ }
+ });
+{% endhighlight %}
+
+The `words` is then mapped to a [JavaPairDStream](api/streaming/index.html#spark.streaming.api.java.JavaPairDStream) of `(word, 1)` pairs using `map` and [PairFunction](api/core/index.html#spark.api.java.function.PairFunction). This is reduced by using `reduceByKey` and [Function2](api/core/index.html#spark.api.java.function.Function2).
+
+{% highlight java %}
+JavaPairDStream<String, Integer> wordCounts = words.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) throws Exception {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+{% endhighlight %}
-* Checkpointing of DStream graph
-* Recovery from master faults
-* Current state and future directions \ No newline at end of file
+# Where to Go from Here
+* Documentation - [Scala](api/streaming/index.html#spark.streaming.package) and [Java](api/streaming/index.html#spark.streaming.api.java.package)
+* More examples - [Scala](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples) and [Java](https://github.com/mesos/spark/tree/master/examples/src/main/java/spark/streaming/examples)
diff --git a/docs/tuning.md b/docs/tuning.md
index 9aaa53cd65..843380b9a2 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -213,10 +213,10 @@ but at a high level, managing how frequently full GC takes place can help in red
Clusters will not be fully utilized unless you set the level of parallelism for each operation high
enough. Spark automatically sets the number of "map" tasks to run on each file according to its size
-(though you can control it through optional parameters to `SparkContext.textFile`, etc), but for
-distributed "reduce" operations, such as `groupByKey` and `reduceByKey`, it uses a default value of 8.
-You can pass the level of parallelism as a second argument (see the
-[`spark.PairRDDFunctions`](api/core/index.html#spark.PairRDDFunctions) documentation),
+(though you can control it through optional parameters to `SparkContext.textFile`, etc), and for
+distributed "reduce" operations, such as `groupByKey` and `reduceByKey`, it uses the largest
+parent RDD's number of partitions. You can pass the level of parallelism as a second argument
+(see the [`spark.PairRDDFunctions`](api/core/index.html#spark.PairRDDFunctions) documentation),
or set the system property `spark.default.parallelism` to change the default.
In general, we recommend 2-3 tasks per CPU core in your cluster.
@@ -233,7 +233,7 @@ number of cores in your clusters.
## Broadcasting Large Variables
-Using the [broadcast functionality](scala-programming-guide#broadcast-variables)
+Using the [broadcast functionality](scala-programming-guide.html#broadcast-variables)
available in `SparkContext` can greatly reduce the size of each serialized task, and the cost
of launching a job over a cluster. If your tasks use any large object from the driver program
inside of them (e.g. a static lookup table), consider turning it into a broadcast variable.
diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
new file mode 100644
index 0000000000..166a884c88
--- /dev/null
+++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# These variables are automatically filled in by the mesos-ec2 script.
+export MESOS_MASTERS="{{master_list}}"
+export MESOS_SLAVES="{{slave_list}}"
+export MESOS_ZOO_LIST="{{zoo_list}}"
+export MESOS_HDFS_DATA_DIRS="{{hdfs_data_dirs}}"
+export MESOS_MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}"
+export MESOS_SPARK_LOCAL_DIRS="{{spark_local_dirs}}"
+export MODULES="{{modules}}"
+export SWAP_MB="{{swap}}"
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index a5384d3bda..66b1faf2cd 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -82,12 +82,21 @@ def parse_args():
parser.add_option("--spot-price", metavar="PRICE", type="float",
help="If specified, launch slaves as spot instances with the given " +
"maximum price (in dollars)")
- parser.add_option("-c", "--cluster-type", default="mesos",
- help="'mesos' for a mesos cluster, 'standalone' for a standalone spark cluster (default: mesos)")
+ parser.add_option("--cluster-type", type="choice", metavar="TYPE",
+ choices=["mesos", "standalone"], default="mesos",
+ help="'mesos' for a Mesos cluster, 'standalone' for a standalone " +
+ "Spark cluster (default: mesos)")
+ parser.add_option("--ganglia", action="store_true", default=True,
+ help="Setup Ganglia monitoring on cluster (default: on). NOTE: " +
+ "the Ganglia page will be publicly accessible")
+ parser.add_option("--no-ganglia", action="store_false", dest="ganglia",
+ help="Disable Ganglia monitoring for the cluster")
+ parser.add_option("--new-scripts", action="store_true", default=False,
+ help="Use new spark-ec2 scripts, for Spark >= 0.7 AMIs")
parser.add_option("-u", "--user", default="root",
- help="The ssh user you want to connect as (default: root)")
+ help="The SSH user you want to connect as (default: root)")
parser.add_option("--delete-groups", action="store_true", default=False,
- help="When destroying a cluster, also destroy the security groups that were created")
+ help="When destroying a cluster, delete the security groups that were created")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -164,22 +173,23 @@ def launch_cluster(conn, opts, cluster_name):
master_group.authorize(src_group=zoo_group)
master_group.authorize('tcp', 22, 22, '0.0.0.0/0')
master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
+ master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0')
+ master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0')
+ master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0')
if opts.cluster_type == "mesos":
- master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0')
- master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0')
- master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0')
master_group.authorize('tcp', 38090, 38090, '0.0.0.0/0')
+ if opts.ganglia:
+ master_group.authorize('tcp', 5080, 5080, '0.0.0.0/0')
if slave_group.rules == []: # Group was just now created
slave_group.authorize(src_group=master_group)
slave_group.authorize(src_group=slave_group)
slave_group.authorize(src_group=zoo_group)
slave_group.authorize('tcp', 22, 22, '0.0.0.0/0')
slave_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
- if opts.cluster_type == "mesos":
- slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0')
- slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0')
- slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0')
- slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0')
+ slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0')
+ slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0')
+ slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0')
+ slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0')
if zoo_group.rules == []: # Group was just now created
zoo_group.authorize(src_group=master_group)
zoo_group.authorize(src_group=slave_group)
@@ -358,19 +368,38 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
# Deploy configuration files and run setup scripts on a newly launched
# or started EC2 cluster.
def setup_cluster(conn, master_nodes, slave_nodes, zoo_nodes, opts, deploy_ssh_key):
- print "Deploying files to master..."
- deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, zoo_nodes)
master = master_nodes[0].public_dns_name
if deploy_ssh_key:
print "Copying SSH key %s to master..." % opts.identity_file
ssh(master, opts, 'mkdir -p ~/.ssh')
scp(master, opts, opts.identity_file, '~/.ssh/id_rsa')
ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa')
- print "Running setup on master..."
+
if opts.cluster_type == "mesos":
- setup_mesos_cluster(master, opts)
+ modules = ['ephemeral-hdfs', 'persistent-hdfs', 'mesos']
elif opts.cluster_type == "standalone":
- setup_standalone_cluster(master, slave_nodes, opts)
+ modules = ['ephemeral-hdfs', 'persistent-hdfs', 'spark-standalone']
+
+ if opts.ganglia:
+ modules.append('ganglia')
+
+ if opts.new_scripts:
+ # NOTE: We should clone the repository before running deploy_files to
+ # prevent ec2-variables.sh from being overwritten
+ ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git")
+
+ print "Deploying files to master..."
+ deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes,
+ zoo_nodes, modules)
+
+ print "Running setup on master..."
+ if not opts.new_scripts:
+ if opts.cluster_type == "mesos":
+ setup_mesos_cluster(master, opts)
+ elif opts.cluster_type == "standalone":
+ setup_standalone_cluster(master, slave_nodes, opts)
+ else:
+ setup_spark_cluster(master, opts)
print "Done!"
def setup_mesos_cluster(master, opts):
@@ -383,6 +412,17 @@ def setup_standalone_cluster(master, slave_nodes, opts):
ssh(master, opts, "echo \"%s\" > spark/conf/slaves" % (slave_ips))
ssh(master, opts, "/root/spark/bin/start-all.sh")
+def setup_spark_cluster(master, opts):
+ ssh(master, opts, "chmod u+x spark-ec2/setup.sh")
+ ssh(master, opts, "spark-ec2/setup.sh")
+ if opts.cluster_type == "mesos":
+ print "Mesos cluster started at http://%s:8080" % master
+ elif opts.cluster_type == "standalone":
+ print "Spark standalone cluster started at http://%s:8080" % master
+
+ if opts.ganglia:
+ print "Ganglia started at http://%s:5080/ganglia" % master
+
# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up
def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes, zoo_nodes):
@@ -427,7 +467,8 @@ def get_num_disks(instance_type):
# cluster (e.g. lists of masters and slaves). Files are only deployed to
# the first master instance in the cluster, and we expect the setup
# script to be run on that instance to copy them to other nodes.
-def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes):
+def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes,
+ modules):
active_master = master_nodes[0].public_dns_name
num_disks = get_num_disks(opts.instance_type)
@@ -459,7 +500,9 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes):
"cluster_url": cluster_url,
"hdfs_data_dirs": hdfs_data_dirs,
"mapred_local_dirs": mapred_local_dirs,
- "spark_local_dirs": spark_local_dirs
+ "spark_local_dirs": spark_local_dirs,
+ "swap": str(opts.swap),
+ "modules": '\n'.join(modules)
}
# Create a temp directory in which we will place all the files to be
diff --git a/examples/pom.xml b/examples/pom.xml
index f43af670c6..7d975875fa 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -20,11 +20,10 @@
<artifactId>jetty-server</artifactId>
</dependency>
<dependency>
- <groupId>org.twitter4j</groupId>
- <artifactId>twitter4j-stream</artifactId>
- <version>3.0.3</version>
+ <groupId>com.twitter</groupId>
+ <artifactId>algebird-core_2.9.2</artifactId>
+ <version>0.1.8</version>
</dependency>
-
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version}</artifactId>
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java
index cddce16e39..cddce16e39 100644
--- a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java
+++ b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java
index 4299febfd6..0e9eadd01b 100644
--- a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java
+++ b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java
@@ -23,7 +23,7 @@ import spark.streaming.api.java.JavaStreamingContext;
*/
public class JavaNetworkWordCount {
public static void main(String[] args) {
- if (args.length < 2) {
+ if (args.length < 3) {
System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1");
System.exit(1);
@@ -35,7 +35,7 @@ public class JavaNetworkWordCount {
// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited test (eg. generated by 'nc')
- JavaDStream<String> lines = ssc.networkTextStream(args[1], Integer.parseInt(args[2]));
+ JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2]));
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable<String> call(String x) {
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java
index 43c3cd4dfa..43c3cd4dfa 100644
--- a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java
+++ b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java
diff --git a/examples/src/main/scala/spark/examples/LogQuery.scala b/examples/src/main/scala/spark/examples/LogQuery.scala
new file mode 100644
index 0000000000..5330b8da94
--- /dev/null
+++ b/examples/src/main/scala/spark/examples/LogQuery.scala
@@ -0,0 +1,66 @@
+package spark.examples
+
+import spark.SparkContext
+import spark.SparkContext._
+/**
+ * Executes a roll up-style query against Apache logs.
+ */
+object LogQuery {
+ val exampleApacheLogs = List(
+ """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg
+ | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1;
+ | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR
+ | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR
+ | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 ""
+ | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""),
+ """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg
+ | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1;
+ | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR
+ | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR
+ | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 ""
+ | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "")
+ )
+
+ def main(args: Array[String]) {
+ if (args.length == 0) {
+ System.err.println("Usage: LogQuery <master> [logFile]")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "Log Query")
+
+ val dataSet =
+ if (args.length == 2) sc.textFile(args(1))
+ else sc.parallelize(exampleApacheLogs)
+
+ val apacheLogRegex =
+ """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r
+
+ /** Tracks the total query count and number of aggregate bytes for a particular group. */
+ class Stats(val count: Int, val numBytes: Int) extends Serializable {
+ def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes)
+ override def toString = "bytes=%s\tn=%s".format(numBytes, count)
+ }
+
+ def extractKey(line: String): (String, String, String) = {
+ apacheLogRegex.findFirstIn(line) match {
+ case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) =>
+ if (user != "\"-\"") (ip, user, query)
+ else (null, null, null)
+ case _ => (null, null, null)
+ }
+ }
+
+ def extractStats(line: String): Stats = {
+ apacheLogRegex.findFirstIn(line) match {
+ case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) =>
+ new Stats(1, bytes.toInt)
+ case _ => new Stats(1, 0)
+ }
+ }
+
+ dataSet.map(line => (extractKey(line), extractStats(line)))
+ .reduceByKey((a, b) => a.merge(b))
+ .collect().foreach{
+ case (user, query) => println("%s\t%s".format(user, query))}
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala
new file mode 100644
index 0000000000..76293fbb96
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala
@@ -0,0 +1,157 @@
+package spark.streaming.examples
+
+import scala.collection.mutable.LinkedList
+import scala.util.Random
+
+import akka.actor.Actor
+import akka.actor.ActorRef
+import akka.actor.Props
+import akka.actor.actorRef2Scala
+
+import spark.streaming.Seconds
+import spark.streaming.StreamingContext
+import spark.streaming.StreamingContext.toPairDStreamFunctions
+import spark.streaming.receivers.Receiver
+import spark.util.AkkaUtils
+
+case class SubscribeReceiver(receiverActor: ActorRef)
+case class UnsubscribeReceiver(receiverActor: ActorRef)
+
+/**
+ * Sends the random content to every receiver subscribed with 1/2
+ * second delay.
+ */
+class FeederActor extends Actor {
+
+ val rand = new Random()
+ var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]()
+
+ val strings: Array[String] = Array("words ", "may ", "count ")
+
+ def makeMessage(): String = {
+ val x = rand.nextInt(3)
+ strings(x) + strings(2 - x)
+ }
+
+ /*
+ * A thread to generate random messages
+ */
+ new Thread() {
+ override def run() {
+ while (true) {
+ Thread.sleep(500)
+ receivers.foreach(_ ! makeMessage)
+ }
+ }
+ }.start()
+
+ def receive: Receive = {
+
+ case SubscribeReceiver(receiverActor: ActorRef) =>
+ println("received subscribe from %s".format(receiverActor.toString))
+ receivers = LinkedList(receiverActor) ++ receivers
+
+ case UnsubscribeReceiver(receiverActor: ActorRef) =>
+ println("received unsubscribe from %s".format(receiverActor.toString))
+ receivers = receivers.dropWhile(x => x eq receiverActor)
+
+ }
+}
+
+/**
+ * A sample actor as receiver, is also simplest. This receiver actor
+ * goes and subscribe to a typical publisher/feeder actor and receives
+ * data.
+ *
+ * @see [[spark.streaming.examples.FeederActor]]
+ */
+class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String)
+extends Actor with Receiver {
+
+ lazy private val remotePublisher = context.actorFor(urlOfPublisher)
+
+ override def preStart = remotePublisher ! SubscribeReceiver(context.self)
+
+ def receive = {
+ case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T])
+ }
+
+ override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self)
+
+}
+
+/**
+ * A sample feeder actor
+ *
+ * Usage: FeederActor <hostname> <port>
+ * <hostname> and <port> describe the AkkaSystem that Spark Sample feeder would start on.
+ */
+object FeederActor {
+
+ def main(args: Array[String]) {
+ if(args.length < 2){
+ System.err.println(
+ "Usage: FeederActor <hostname> <port>\n"
+ )
+ System.exit(1)
+ }
+ val Seq(host, port) = args.toSeq
+
+
+ val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1
+ val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor")
+
+ println("Feeder started as:" + feeder)
+
+ actorSystem.awaitTermination();
+ }
+}
+
+/**
+ * A sample word count program demonstrating the use of plugging in
+ * Actor as Receiver
+ * Usage: ActorWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the AkkaSystem that Spark Sample feeder is running on.
+ *
+ * To run this example locally, you may run Feeder Actor as
+ * `$ ./run spark.streaming.examples.FeederActor 127.0.1.1 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999`
+ */
+object ActorWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ActorWordCount <master> <hostname> <port>" +
+ "In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+
+ val Seq(master, host, port) = args.toSeq
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2))
+
+ /*
+ * Following is the use of actorStream to plug in custom actor as receiver
+ *
+ * An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e type of data received and InputDstream
+ * should be same.
+ *
+ * For example: Both actorStream and SampleActorReceiver are parameterized
+ * to same type to ensure type safety.
+ */
+
+ val lines = ssc.actorStream[String](
+ Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format(
+ host, port.toInt))), "SampleReceiver")
+
+ //compute wordcount
+ lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print()
+
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
index fe55db6e2c..9b135a5c54 100644
--- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -10,22 +10,34 @@ import spark.streaming.StreamingContext._
import spark.storage.StorageLevel
import spark.streaming.util.RawTextHelper._
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <zkQuorum> is a list of one or more zookeeper servers that make quorum
+ * <group> is the name of kafka consumer group
+ * <topics> is a list of one or more kafka topics to consume from
+ * <numThreads> is the number of threads the kafka consumer should use
+ *
+ * Example:
+ * `./run spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1`
+ */
object KafkaWordCount {
def main(args: Array[String]) {
- if (args.length < 6) {
- System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>")
+ if (args.length < 5) {
+ System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>")
System.exit(1)
}
- val Array(master, hostname, port, group, topics, numThreads) = args
+ val Array(master, zkQuorum, group, topics, numThreads) = args
val sc = new SparkContext(master, "KafkaWordCount")
val ssc = new StreamingContext(sc, Seconds(2))
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap)
+ val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
@@ -38,16 +50,16 @@ object KafkaWordCount {
object KafkaWordCountProducer {
def main(args: Array[String]) {
- if (args.length < 3) {
- System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>")
+ if (args.length < 2) {
+ System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
- val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args
+ val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
// Zookeper connection properties
val props = new Properties()
- props.put("zk.connect", hostname + ":" + port)
+ props.put("zk.connect", zkQuorum)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
index 32f7d57bea..5ac6d19b34 100644
--- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
+++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
@@ -16,7 +16,7 @@ import spark.streaming.StreamingContext._
*/
object NetworkWordCount {
def main(args: Array[String]) {
- if (args.length < 2) {
+ if (args.length < 3) {
System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
"In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
@@ -27,7 +27,7 @@ object NetworkWordCount {
// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited test (eg. generated by 'nc')
- val lines = ssc.networkTextStream(args(1), args(2).toInt)
+ val lines = ssc.socketTextStream(args(1), args(2).toInt)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
wordCounts.print()
diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
index 2eec777c54..66e709b7a3 100644
--- a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
+++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
@@ -37,7 +37,7 @@ object RawNetworkGrep {
RawTextHelper.warmUp(ssc.sc)
val rawStreams = (1 to numStreams).map(_ =>
- ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
+ ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
val union = ssc.union(rawStreams)
union.filter(_.contains("the")).count().foreach(r =>
println("Grep count: " + r.collect().mkString))
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
new file mode 100644
index 0000000000..39a1a702ee
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala
@@ -0,0 +1,93 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.storage.StorageLevel
+import com.twitter.algebird._
+import spark.streaming.StreamingContext._
+import spark.SparkContext._
+
+/**
+ * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute
+ * windowed and global Top-K estimates of user IDs occurring in a Twitter stream.
+ * <br>
+ * <strong>Note</strong> that since Algebird's implementation currently only supports Long inputs,
+ * the example operates on Long IDs. Once the implementation supports other inputs (such as String),
+ * the same approach could be used for computing popular topics for example.
+ * <p>
+ * <p>
+ * <a href="http://highlyscalable.wordpress.com/2012/05/01/probabilistic-structures-web-analytics-data-mining/">
+ * This blog post</a> has a good overview of the Count-Min Sketch (CMS). The CMS is a datastructure
+ * for approximate frequency estimation in data streams (e.g. Top-K elements, frequency of any given element, etc),
+ * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the
+ * estimated count of an element can be computed, as well as "heavy-hitters" that occur more than a threshold
+ * percentage of the overall total count.
+ * <p><p>
+ * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation.
+ */
+object TwitterAlgebirdCMS {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: TwitterAlgebirdCMS <master> <twitter_username> <twitter_password>" +
+ " [filter1] [filter2] ... [filter n]")
+ System.exit(1)
+ }
+
+ // CMS parameters
+ val DELTA = 1E-3
+ val EPS = 0.01
+ val SEED = 1
+ val PERC = 0.001
+ // K highest frequency elements to take
+ val TOPK = 10
+
+ val Array(master, username, password) = args.slice(0, 3)
+ val filters = args.slice(3, args.length)
+
+ val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10))
+ val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+
+ val users = stream.map(status => status.getUser.getId)
+
+ val cms = new CountMinSketchMonoid(DELTA, EPS, SEED, PERC)
+ var globalCMS = cms.zero
+ val mm = new MapMonoid[Long, Int]()
+ var globalExact = Map[Long, Int]()
+
+ val approxTopUsers = users.mapPartitions(ids => {
+ ids.map(id => cms.create(id))
+ }).reduce(_ ++ _)
+
+ val exactTopUsers = users.map(id => (id, 1))
+ .reduceByKey((a, b) => a + b)
+
+ approxTopUsers.foreach(rdd => {
+ if (rdd.count() != 0) {
+ val partial = rdd.first()
+ val partialTopK = partial.heavyHitters.map(id =>
+ (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK)
+ globalCMS ++= partial
+ val globalTopK = globalCMS.heavyHitters.map(id =>
+ (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK)
+ println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC,
+ partialTopK.mkString("[", ",", "]")))
+ println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC,
+ globalTopK.mkString("[", ",", "]")))
+ }
+ })
+
+ exactTopUsers.foreach(rdd => {
+ if (rdd.count() != 0) {
+ val partialMap = rdd.collect().toMap
+ val partialTopK = rdd.map(
+ {case (id, count) => (count, id)})
+ .sortByKey(ascending = false).take(TOPK)
+ globalExact = mm.plus(globalExact.toMap, partialMap)
+ val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK)
+ println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]")))
+ println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]")))
+ }
+ })
+
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
new file mode 100644
index 0000000000..914fba4ca2
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala
@@ -0,0 +1,71 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.storage.StorageLevel
+import com.twitter.algebird.HyperLogLog._
+import com.twitter.algebird.HyperLogLogMonoid
+import spark.streaming.dstream.TwitterInputDStream
+
+/**
+ * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute
+ * a windowed and global estimate of the unique user IDs occurring in a Twitter stream.
+ * <p>
+ * <p>
+ * This <a href="http://highlyscalable.wordpress.com/2012/05/01/probabilistic-structures-web-analytics-data-mining/">
+ * blog post</a> and this
+ * <a href="http://highscalability.com/blog/2012/4/5/big-data-counting-how-to-count-a-billion-distinct-objects-us.html">blog post</a>
+ * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for estimating
+ * the cardinality of a data stream, i.e. the number of unique elements.
+ * <p><p>
+ * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation.
+ */
+object TwitterAlgebirdHLL {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: TwitterAlgebirdHLL <master> <twitter_username> <twitter_password>" +
+ " [filter1] [filter2] ... [filter n]")
+ System.exit(1)
+ }
+
+ /** Bit size parameter for HyperLogLog, trades off accuracy vs size */
+ val BIT_SIZE = 12
+ val Array(master, username, password) = args.slice(0, 3)
+ val filters = args.slice(3, args.length)
+
+ val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5))
+ val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER)
+
+ val users = stream.map(status => status.getUser.getId)
+
+ val hll = new HyperLogLogMonoid(BIT_SIZE)
+ var globalHll = hll.zero
+ var userSet: Set[Long] = Set()
+
+ val approxUsers = users.mapPartitions(ids => {
+ ids.map(id => hll(id))
+ }).reduce(_ + _)
+
+ val exactUsers = users.map(id => Set(id)).reduce(_ ++ _)
+
+ approxUsers.foreach(rdd => {
+ if (rdd.count() != 0) {
+ val partial = rdd.first()
+ globalHll += partial
+ println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt))
+ println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt))
+ }
+ })
+
+ exactUsers.foreach(rdd => {
+ if (rdd.count() != 0) {
+ val partial = rdd.first()
+ userSet ++= partial
+ println("Exact distinct users this batch: %d".format(partial.size))
+ println("Exact distinct users overall: %d".format(userSet.size))
+ println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100))
+ }
+ })
+
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
index 377bc0c98e..fdb3a4c73c 100644
--- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala
+++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala
@@ -1,19 +1,19 @@
-package spark.streaming.examples.twitter
+package spark.streaming.examples
-import spark.streaming.StreamingContext._
import spark.streaming.{Seconds, StreamingContext}
+import StreamingContext._
import spark.SparkContext._
-import spark.storage.StorageLevel
/**
* Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter
* stream. The stream is instantiated with credentials and optionally filters supplied by the
* command line arguments.
+ *
*/
-object TwitterBasic {
+object TwitterPopularTags {
def main(args: Array[String]) {
if (args.length < 3) {
- System.err.println("Usage: TwitterBasic <master> <twitter_username> <twitter_password>" +
+ System.err.println("Usage: TwitterPopularTags <master> <twitter_username> <twitter_password>" +
" [filter1] [filter2] ... [filter n]")
System.exit(1)
}
@@ -21,10 +21,8 @@ object TwitterBasic {
val Array(master, username, password) = args.slice(0, 3)
val filters = args.slice(3, args.length)
- val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2))
- val stream = new TwitterInputDStream(ssc, username, password, filters,
- StorageLevel.MEMORY_ONLY_SER)
- ssc.registerInputStream(stream)
+ val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2))
+ val stream = ssc.twitterStream(username, password, filters)
val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
@@ -39,22 +37,17 @@ object TwitterBasic {
// Print popular hashtags
topCounts60.foreach(rdd => {
- if (rdd.count() != 0) {
- val topList = rdd.take(5)
- println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count()))
- topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
- }
+ val topList = rdd.take(5)
+ println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count()))
+ topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
})
topCounts10.foreach(rdd => {
- if (rdd.count() != 0) {
- val topList = rdd.take(5)
- println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count()))
- topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
- }
+ val topList = rdd.take(5)
+ println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count()))
+ topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
})
ssc.start()
}
-
}
diff --git a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala
new file mode 100644
index 0000000000..5ed9b7cb76
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala
@@ -0,0 +1,73 @@
+package spark.streaming.examples
+
+import akka.actor.ActorSystem
+import akka.actor.actorRef2Scala
+import akka.zeromq._
+import spark.streaming.{ Seconds, StreamingContext }
+import spark.streaming.StreamingContext._
+import akka.zeromq.Subscribe
+
+/**
+ * A simple publisher for demonstration purposes, repeatedly publishes random Messages
+ * every one second.
+ */
+object SimpleZeroMQPublisher {
+
+ def main(args: Array[String]) = {
+ if (args.length < 2) {
+ System.err.println("Usage: SimpleZeroMQPublisher <zeroMQUrl> <topic> ")
+ System.exit(1)
+ }
+
+ val Seq(url, topic) = args.toSeq
+ val acs: ActorSystem = ActorSystem()
+
+ val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url))
+ val messages: Array[String] = Array("words ", "may ", "count ")
+ while (true) {
+ Thread.sleep(1000)
+ pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList)
+ }
+ acs.awaitTermination()
+ }
+}
+
+/**
+ * A sample wordcount with ZeroMQStream stream
+ *
+ * To work with zeroMQ, some native libraries have to be installed.
+ * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software)
+ *
+ * Usage: ZeroMQWordCount <master> <zeroMQurl> <topic>
+ * In local mode, <master> should be 'local[n]' with n > 1
+ * <zeroMQurl> and <topic> describe where zeroMq publisher is running.
+ *
+ * To run this example locally, you may run publisher as
+ * `$ ./run spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar`
+ * and run the example as
+ * `$ ./run spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo`
+ */
+object ZeroMQWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ZeroMQWordCount <master> <zeroMQurl> <topic>" +
+ "In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+ val Seq(master, url, topic) = args.toSeq
+
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2))
+
+ def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator
+
+ //For this stream, a zeroMQ publisher should be running.
+ val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+
+} \ No newline at end of file
diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
index a191321d91..fba72519a9 100644
--- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
+++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
@@ -27,17 +27,16 @@ object PageViewStream {
val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1))
// Create a NetworkInputDStream on target host:port and convert each line to a PageView
- val pageViews = ssc.networkTextStream(host, port)
- .flatMap(_.split("\n"))
- .map(PageView.fromString(_))
+ val pageViews = ssc.socketTextStream(host, port)
+ .flatMap(_.split("\n"))
+ .map(PageView.fromString(_))
// Return a count of views per URL seen in each batch
- val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey()
+ val pageCounts = pageViews.map(view => view.url).countByValue()
// Return a sliding window of page views per URL in the last ten seconds
- val slidingPageCounts = pageViews.map(view => ((view.url, 1)))
- .window(Seconds(10), Seconds(2))
- .countByKey()
+ val slidingPageCounts = pageViews.map(view => view.url)
+ .countByValueAndWindow(Seconds(10), Seconds(2))
// Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds
diff --git a/pom.xml b/pom.xml
index 7e06cae052..99eb17856a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -84,9 +84,9 @@
</snapshots>
</repository>
<repository>
- <id>typesafe-repo</id>
- <name>Typesafe Repository</name>
- <url>http://repo.typesafe.com/typesafe/releases/</url>
+ <id>akka-repo</id>
+ <name>Akka Repository</name>
+ <url>http://repo.akka.io/releases/</url>
<releases>
<enabled>true</enabled>
</releases>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index af8b5ba017..25c2328373 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -44,6 +44,9 @@ object SparkBuild extends Build {
transitiveClassifiers in Scope.GlobalScope := Seq("sources"),
testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))),
+ // shared between both core and streaming.
+ resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
+
// For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@@ -114,7 +117,6 @@ object SparkBuild extends Build {
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
resolvers ++= Seq(
- "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/",
"JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
"Spray Repository" at "http://repo.spray.cc/",
"Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
@@ -155,9 +157,7 @@ object SparkBuild extends Build {
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
- libraryDependencies ++= Seq(
- "org.twitter4j" % "twitter4j-stream" % "3.0.3"
- )
+ libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.8")
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
@@ -166,7 +166,9 @@ object SparkBuild extends Build {
name := "spark-streaming",
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile",
- "com.github.sgroschupf" % "zkclient" % "0.1"
+ "com.github.sgroschupf" % "zkclient" % "0.1",
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3",
+ "com.typesafe.akka" % "akka-zeromq" % "2.0.3"
)
) ++ assemblySettings ++ extraAssemblySettings
diff --git a/pyspark b/pyspark
index ab7f4f50c0..d662e90287 100755
--- a/pyspark
+++ b/pyspark
@@ -36,4 +36,9 @@ if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then
export SPARK_LAUNCH_WITH_SCALA=1
fi
-exec "$PYSPARK_PYTHON" "$@"
+if [[ "$IPYTHON" = "1" ]] ; then
+ export PYSPARK_PYTHON="ipython"
+ exec "$PYSPARK_PYTHON" -i -c "%run $PYTHONSTARTUP"
+else
+ exec "$PYSPARK_PYTHON" "$@"
+fi
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 7036c47980..5f4294fb1b 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -32,13 +32,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
-def _do_python_join(rdd, other, numSplits, dispatch):
+def _do_python_join(rdd, other, numPartitions, dispatch):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
- return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
+ return vs.union(ws).groupByKey(numPartitions).flatMapValues(dispatch)
-def python_join(rdd, other, numSplits):
+def python_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -47,10 +47,10 @@ def python_join(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_right_outer_join(rdd, other, numSplits):
+def python_right_outer_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -61,10 +61,10 @@ def python_right_outer_join(rdd, other, numSplits):
if not vbuf:
vbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_left_outer_join(rdd, other, numSplits):
+def python_left_outer_join(rdd, other, numPartitions):
def dispatch(seq):
vbuf, wbuf = [], []
for (n, v) in seq:
@@ -75,10 +75,10 @@ def python_left_outer_join(rdd, other, numSplits):
if not wbuf:
wbuf.append(None)
return [(v, w) for v in vbuf for w in wbuf]
- return _do_python_join(rdd, other, numSplits, dispatch)
+ return _do_python_join(rdd, other, numPartitions, dispatch)
-def python_cogroup(rdd, other, numSplits):
+def python_cogroup(rdd, other, numPartitions):
vs = rdd.map(lambda (k, v): (k, (1, v)))
ws = other.map(lambda (k, v): (k, (2, v)))
def dispatch(seq):
@@ -89,4 +89,4 @@ def python_cogroup(rdd, other, numSplits):
elif n == 2:
wbuf.append(v)
return (vbuf, wbuf)
- return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
+ return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4cda6cf661..172ed85fab 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -215,7 +215,7 @@ class RDD(object):
yield pair
return java_cartesian.flatMap(unpack_batches)
- def groupBy(self, f, numSplits=None):
+ def groupBy(self, f, numPartitions=None):
"""
Return an RDD of grouped items.
@@ -224,7 +224,7 @@ class RDD(object):
>>> sorted([(x, sorted(y)) for (x, y) in result])
[(0, [2, 8]), (1, [1, 1, 3, 5])]
"""
- return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
+ return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
def pipe(self, command, env={}):
"""
@@ -274,8 +274,8 @@ class RDD(object):
def reduce(self, f):
"""
- Reduces the elements of this RDD using the specified associative binary
- operator.
+ Reduces the elements of this RDD using the specified commutative and
+ associative binary operator.
>>> from operator import add
>>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
@@ -422,22 +422,22 @@ class RDD(object):
"""
return dict(self.collect())
- def reduceByKey(self, func, numSplits=None):
+ def reduceByKey(self, func, numPartitions=None):
"""
Merge the values for each key using an associative reduce function.
This will also perform the merging locally on each mapper before
sending results to a reducer, similarly to a "combiner" in MapReduce.
- Output will be hash-partitioned with C{numSplits} splits, or the
- default parallelism level if C{numSplits} is not specified.
+ Output will be hash-partitioned with C{numPartitions} partitions, or
+ the default parallelism level if C{numPartitions} is not specified.
>>> from operator import add
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(rdd.reduceByKey(add).collect())
[('a', 2), ('b', 1)]
"""
- return self.combineByKey(lambda x: x, func, func, numSplits)
+ return self.combineByKey(lambda x: x, func, func, numPartitions)
def reduceByKeyLocally(self, func):
"""
@@ -474,7 +474,7 @@ class RDD(object):
"""
return self.map(lambda x: x[0]).countByValue()
- def join(self, other, numSplits=None):
+ def join(self, other, numPartitions=None):
"""
Return an RDD containing all pairs of elements with matching keys in
C{self} and C{other}.
@@ -489,9 +489,9 @@ class RDD(object):
>>> sorted(x.join(y).collect())
[('a', (1, 2)), ('a', (1, 3))]
"""
- return python_join(self, other, numSplits)
+ return python_join(self, other, numPartitions)
- def leftOuterJoin(self, other, numSplits=None):
+ def leftOuterJoin(self, other, numPartitions=None):
"""
Perform a left outer join of C{self} and C{other}.
@@ -506,9 +506,9 @@ class RDD(object):
>>> sorted(x.leftOuterJoin(y).collect())
[('a', (1, 2)), ('b', (4, None))]
"""
- return python_left_outer_join(self, other, numSplits)
+ return python_left_outer_join(self, other, numPartitions)
- def rightOuterJoin(self, other, numSplits=None):
+ def rightOuterJoin(self, other, numPartitions=None):
"""
Perform a right outer join of C{self} and C{other}.
@@ -523,10 +523,10 @@ class RDD(object):
>>> sorted(y.rightOuterJoin(x).collect())
[('a', (2, 1)), ('b', (None, 4))]
"""
- return python_right_outer_join(self, other, numSplits)
+ return python_right_outer_join(self, other, numPartitions)
# TODO: add option to control map-side combining
- def partitionBy(self, numSplits, partitionFunc=hash):
+ def partitionBy(self, numPartitions, partitionFunc=hash):
"""
Return a copy of the RDD partitioned using the specified partitioner.
@@ -535,22 +535,22 @@ class RDD(object):
>>> set(sets[0]).intersection(set(sets[1]))
set([])
"""
- if numSplits is None:
- numSplits = self.ctx.defaultParallelism
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
# Transferring O(n) objects to Java is too expensive. Instead, we'll
- # form the hash buckets in Python, transferring O(numSplits) objects
+ # form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
for (k, v) in iterator:
- buckets[partitionFunc(k) % numSplits].append((k, v))
+ buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield str(split)
yield dump_pickle(Batch(items))
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx)
@@ -561,7 +561,7 @@ class RDD(object):
# TODO: add control over map-side aggregation
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
- numSplits=None):
+ numPartitions=None):
"""
Generic function to combine the elements for each key using a custom
set of aggregation functions.
@@ -586,8 +586,8 @@ class RDD(object):
>>> sorted(x.combineByKey(str, add, add).collect())
[('a', '11'), ('b', '1')]
"""
- if numSplits is None:
- numSplits = self.ctx.defaultParallelism
+ if numPartitions is None:
+ numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
for (k, v) in iterator:
@@ -597,7 +597,7 @@ class RDD(object):
combiners[k] = mergeValue(combiners[k], v)
return combiners.iteritems()
locally_combined = self.mapPartitions(combineLocally)
- shuffled = locally_combined.partitionBy(numSplits)
+ shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
combiners = {}
for (k, v) in iterator:
@@ -609,10 +609,10 @@ class RDD(object):
return shuffled.mapPartitions(_mergeCombiners)
# TODO: support variant with custom partitioner
- def groupByKey(self, numSplits=None):
+ def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
- Hash-partitions the resulting RDD with into numSplits partitions.
+ Hash-partitions the resulting RDD with into numPartitions partitions.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> sorted(x.groupByKey().collect())
@@ -630,7 +630,7 @@ class RDD(object):
return a + b
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
- numSplits)
+ numPartitions)
# TODO: add tests
def flatMapValues(self, f):
@@ -659,7 +659,7 @@ class RDD(object):
return self.cogroup(other)
# TODO: add variant with custom parittioner
- def cogroup(self, other, numSplits=None):
+ def cogroup(self, other, numPartitions=None):
"""
For each key k in C{self} or C{other}, return a resulting RDD that
contains a tuple with the list of values for that key in C{self} as well
@@ -670,7 +670,7 @@ class RDD(object):
>>> sorted(x.cogroup(y).collect())
[('a', ([1], [2])), ('b', ([4], []))]
"""
- return python_cogroup(self, other, numSplits)
+ return python_cogroup(self, other, numPartitions)
# TODO: `lookup` is disabled because we can't make direct comparisons based
# on the key; we need to compare the hash of the key to the hash of the
diff --git a/run b/run
index a094629449..ecbf7673c6 100755
--- a/run
+++ b/run
@@ -13,6 +13,38 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
fi
+if [ -z "$1" ]; then
+ echo "Usage: run <spark-class> [<args>]" >&2
+ exit 1
+fi
+
+# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable
+# values for that; it doesn't need a lot
+if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
+ SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
+ SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
+fi
+
+
+# Add java opts for master, worker, executor. The opts maybe null
+case "$1" in
+ 'spark.deploy.master.Master')
+ SPARK_JAVA_OPTS+=" $SPARK_MASTER_OPTS"
+ ;;
+ 'spark.deploy.worker.Worker')
+ SPARK_JAVA_OPTS+=" $SPARK_WORKER_OPTS"
+ ;;
+ 'spark.executor.StandaloneExecutorBackend')
+ SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS"
+ ;;
+ 'spark.executor.MesosExecutorBackend')
+ SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS"
+ ;;
+ 'spark.repl.Main')
+ SPARK_JAVA_OPTS+=" $SPARK_REPL_OPTS"
+ ;;
+esac
+
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
if [ `command -v scala` ]; then
RUNNER="scala"
@@ -79,11 +111,13 @@ CLASSPATH+=":$FWDIR/conf"
CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes"
if [ -n "$SPARK_TESTING" ] ; then
CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
fi
CLASSPATH+=":$CORE_DIR/src/main/resources"
CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH+=":$FWDIR/lib_managed/jars/*"
CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
diff --git a/run2.cmd b/run2.cmd
index 67f1e465e4..705a4d1ff6 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -11,9 +11,22 @@ set SPARK_HOME=%FWDIR%
rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+rem Test that an argument was given
+if not "x%1"=="x" goto arg_given
+ echo Usage: run ^<spark-class^> [^<args^>]
+ goto exit
+:arg_given
+
+set RUNNING_DAEMON=0
+if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1
+if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
+if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
+if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
+if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
+
rem Check that SCALA_HOME has been specified
if not "x%SCALA_HOME%"=="x" goto scala_exists
- echo "SCALA_HOME is not set"
+ echo SCALA_HOME is not set
goto exit
:scala_exists
@@ -34,16 +47,19 @@ set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
+set STREAMING_DIR=%FWDIR%streaming
set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
+set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
-for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
-for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
-for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
-for /R "%PYSPARK_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
+set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
rem Figure out whether to run our class with java or with the scala launcher.
diff --git a/sbt/sbt.cmd b/sbt/sbt.cmd
index 6b289ab447..ce3ae70174 100644
--- a/sbt/sbt.cmd
+++ b/sbt/sbt.cmd
@@ -2,4 +2,4 @@
set EXTRA_ARGS=
if not "%MESOS_HOME%x"=="x" set EXTRA_ARGS=-Djava.library.path=%MESOS_HOME%\lib\java
set SPARK_HOME=%~dp0..
-java -Xmx1200M -XX:MaxPermSize=200m %EXTRA_ARGS% -jar %SPARK_HOME%\sbt\sbt-launch-*.jar "%*"
+java -Xmx1200M -XX:MaxPermSize=200m %EXTRA_ARGS% -jar %SPARK_HOME%\sbt\sbt-launch-0.11.3-2.jar "%*"
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 6ee7e59df3..15523eadcb 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -47,7 +47,16 @@
<artifactId>zkclient</artifactId>
<version>0.1</version>
</dependency>
-
+ <dependency>
+ <groupId>org.twitter4j</groupId>
+ <artifactId>twitter4j-stream</artifactId>
+ <version>3.0.3</version>
+ </dependency>
+ <dependency>
+ <groupId>com.typesafe.akka</groupId>
+ <artifactId>akka-zeromq</artifactId>
+ <version>2.0.3</version>
+ </dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version}</artifactId>
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 2f3adb39c2..e7a392fbbf 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -6,18 +6,21 @@ import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import java.io._
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+import java.util.concurrent.Executors
private[streaming]
class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
extends Logging with Serializable {
val master = ssc.sc.master
- val framework = ssc.sc.jobName
+ val framework = ssc.sc.appName
val sparkHome = ssc.sc.sparkHome
val jars = ssc.sc.jars
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
- val checkpointDuration: Duration = ssc.checkpointDuration
+ val checkpointDuration = ssc.checkpointDuration
+ val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
def validate() {
assert(master != null, "Checkpoint.master is null")
@@ -37,32 +40,50 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
val conf = new Configuration()
var fs = file.getFileSystem(conf)
val maxAttempts = 3
+ val executor = Executors.newFixedThreadPool(1)
- def write(checkpoint: Checkpoint) {
- // TODO: maybe do this in a different thread from the main stream execution thread
- var attempts = 0
- while (attempts < maxAttempts) {
- attempts += 1
- try {
- logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'")
- if (fs.exists(file)) {
- val bkFile = new Path(file.getParent, file.getName + ".bk")
- FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
- logDebug("Moved existing checkpoint file to " + bkFile)
+ class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable {
+ def run() {
+ var attempts = 0
+ val startTime = System.currentTimeMillis()
+ while (attempts < maxAttempts) {
+ attempts += 1
+ try {
+ logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ if (fs.exists(file)) {
+ val bkFile = new Path(file.getParent, file.getName + ".bk")
+ FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
+ logDebug("Moved existing checkpoint file to " + bkFile)
+ }
+ val fos = fs.create(file)
+ fos.write(bytes)
+ fos.close()
+ fos.close()
+ val finishTime = System.currentTimeMillis();
+ logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
+ "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
+ return
+ } catch {
+ case ioe: IOException =>
+ logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
}
- val fos = fs.create(file)
- val oos = new ObjectOutputStream(fos)
- oos.writeObject(checkpoint)
- oos.close()
- logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'")
- fos.close()
- return
- } catch {
- case ioe: IOException =>
- logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
}
+ logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'")
}
- logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'")
+ }
+
+ def write(checkpoint: Checkpoint) {
+ val bos = new ByteArrayOutputStream()
+ val zos = new LZFOutputStream(bos)
+ val oos = new ObjectOutputStream(zos)
+ oos.writeObject(checkpoint)
+ oos.close()
+ bos.close()
+ executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray))
+ }
+
+ def stop() {
+ executor.shutdown()
}
}
@@ -84,7 +105,8 @@ object CheckpointReader extends Logging {
// of ObjectInputStream is used to explicitly use the current thread's default class
// loader to find and load classes. This is a well know Java issue and has popped up
// in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val zis = new LZFInputStream(fis)
+ val ois = new ObjectInputStreamWithLoader(zis, Thread.currentThread().getContextClassLoader)
val cp = ois.readObject.asInstanceOf[Checkpoint]
ois.close()
fs.close()
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 352f83fe0c..e1be5ef51c 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -12,7 +12,7 @@ import scala.collection.mutable.HashMap
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
/**
@@ -75,7 +75,7 @@ abstract class DStream[T: ClassManifest] (
// Checkpoint details
protected[streaming] val mustCheckpoint = false
protected[streaming] var checkpointDuration: Duration = null
- protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]())
+ protected[streaming] val checkpointData = new DStreamCheckpointData(this)
// Reference to whole DStream graph
protected[streaming] var graph: DStreamGraph = null
@@ -85,10 +85,10 @@ abstract class DStream[T: ClassManifest] (
// Duration for which the DStream requires its parent DStream to remember each RDD created
protected[streaming] def parentRememberDuration = rememberDuration
- /** Returns the StreamingContext associated with this DStream */
- def context() = ssc
+ /** Return the StreamingContext associated with this DStream */
+ def context = ssc
- /** Persists the RDDs of this DStream with the given storage level */
+ /** Persist the RDDs of this DStream with the given storage level */
def persist(level: StorageLevel): DStream[T] = {
if (this.isInitialized) {
throw new UnsupportedOperationException(
@@ -132,7 +132,7 @@ abstract class DStream[T: ClassManifest] (
// Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger
if (mustCheckpoint && checkpointDuration == null) {
- checkpointDuration = slideDuration.max(Seconds(10))
+ checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt
logInfo("Checkpoint interval automatically set to " + checkpointDuration)
}
@@ -159,7 +159,7 @@ abstract class DStream[T: ClassManifest] (
)
assert(
- checkpointDuration == null || ssc.sc.checkpointDir.isDefined,
+ checkpointDuration == null || context.sparkContext.checkpointDir.isDefined,
"The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" +
" or SparkContext.checkpoint() to set the checkpoint directory."
)
@@ -238,13 +238,15 @@ abstract class DStream[T: ClassManifest] (
dependencies.foreach(_.remember(parentRememberDuration))
}
- /** This method checks whether the 'time' is valid wrt slideDuration for generating RDD */
+ /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */
protected def isTimeValid(time: Time): Boolean = {
if (!isInitialized) {
throw new Exception (this + " has not been initialized")
} else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) {
+ logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime))
false
} else {
+ logInfo("Time " + time + " is valid")
true
}
}
@@ -292,14 +294,14 @@ abstract class DStream[T: ClassManifest] (
* Generate a SparkStreaming job for the given time. This is an internal method that
* should not be called directly. This default implementation creates a job
* that materializes the corresponding RDD. Subclasses of DStream may override this
- * (eg. ForEachDStream).
+ * to generate their own jobs.
*/
protected[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
case Some(rdd) => {
val jobFunc = () => {
- val emptyFunc = { (iterator: Iterator[T]) => {} }
- ssc.sc.runJob(rdd, emptyFunc)
+ val emptyFunc = { (iterator: Iterator[T]) => {} }
+ context.sparkContext.runJob(rdd, emptyFunc)
}
Some(new Job(time, jobFunc))
}
@@ -308,20 +310,18 @@ abstract class DStream[T: ClassManifest] (
}
/**
- * Dereference RDDs that are older than rememberDuration.
+ * Clear metadata that are older than `rememberDuration` of this DStream.
+ * This is an internal method that should not be called directly. This default
+ * implementation clears the old generated RDDs. Subclasses of DStream may override
+ * this to clear their own metadata along with the generated RDDs.
*/
- protected[streaming] def forgetOldRDDs(time: Time) {
- val keys = generatedRDDs.keys
+ protected[streaming] def clearOldMetadata(time: Time) {
var numForgotten = 0
- keys.foreach(t => {
- if (t <= (time - rememberDuration)) {
- generatedRDDs.remove(t)
- numForgotten += 1
- logInfo("Forgot RDD of time " + t + " from " + this)
- }
- })
- logInfo("Forgot " + numForgotten + " RDDs from " + this)
- dependencies.foreach(_.forgetOldRDDs(time))
+ val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
+ generatedRDDs --= oldRDDs.keys
+ logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " +
+ (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
+ dependencies.foreach(_.clearOldMetadata(time))
}
/* Adds metadata to the Stream while it is running.
@@ -342,40 +342,10 @@ abstract class DStream[T: ClassManifest] (
*/
protected[streaming] def updateCheckpointData(currentTime: Time) {
logInfo("Updating checkpoint data for time " + currentTime)
-
- // Get the checkpointed RDDs from the generated RDDs
- val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
- .map(x => (x._1, x._2.getCheckpointFile.get))
-
- // Make a copy of the existing checkpoint data (checkpointed RDDs)
- val oldRdds = checkpointData.rdds.clone()
-
- // If the new checkpoint data has checkpoints then replace existing with the new one
- if (newRdds.size > 0) {
- checkpointData.rdds.clear()
- checkpointData.rdds ++= newRdds
- }
-
- // Make parent DStreams update their checkpoint data
+ checkpointData.update()
dependencies.foreach(_.updateCheckpointData(currentTime))
-
- // TODO: remove this, this is just for debugging
- newRdds.foreach {
- case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
- }
-
- if (newRdds.size > 0) {
- (oldRdds -- newRdds.keySet).foreach {
- case (time, data) => {
- val path = new Path(data.toString)
- val fs = path.getFileSystem(new Configuration())
- fs.delete(path, true)
- logInfo("Deleted checkpoint file '" + path + "' for time " + time)
- }
- }
- }
- logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, "
- + "[" + checkpointData.rdds.mkString(",") + "]")
+ checkpointData.cleanup()
+ logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
}
/**
@@ -386,14 +356,8 @@ abstract class DStream[T: ClassManifest] (
*/
protected[streaming] def restoreCheckpointData() {
// Create RDDs from the checkpoint data
- logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs")
- checkpointData.rdds.foreach {
- case(time, data) => {
- logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
- val rdd = ssc.sc.checkpointFile[T](data.toString)
- generatedRDDs += ((time, rdd))
- }
- }
+ logInfo("Restoring checkpoint data")
+ checkpointData.restore()
dependencies.foreach(_.restoreCheckpointData())
logInfo("Restored checkpoint data")
}
@@ -433,7 +397,7 @@ abstract class DStream[T: ClassManifest] (
/** Return a new DStream by applying a function to all elements of this DStream. */
def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
- new MappedDStream(this, ssc.sc.clean(mapFunc))
+ new MappedDStream(this, context.sparkContext.clean(mapFunc))
}
/**
@@ -441,7 +405,7 @@ abstract class DStream[T: ClassManifest] (
* and then flattening the results
*/
def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
- new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc))
+ new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc))
}
/** Return a new DStream containing only the elements that satisfy a predicate. */
@@ -463,7 +427,7 @@ abstract class DStream[T: ClassManifest] (
mapPartFunc: Iterator[T] => Iterator[U],
preservePartitioning: Boolean = false
): DStream[U] = {
- new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning)
+ new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning)
}
/**
@@ -480,6 +444,15 @@ abstract class DStream[T: ClassManifest] (
def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
/**
+ * Return a new DStream in which each RDD contains the counts of each distinct value in
+ * each RDD of this DStream. Hash partitioning is used to generate
+ * the RDDs with `numPartitions` partitions (Spark's default number of partitions if
+ * `numPartitions` not specified).
+ */
+ def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] =
+ this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions)
+
+ /**
* Apply a function to each RDD in this DStream. This is an output operator, so
* this DStream will be registered as an output stream and therefore materialized.
*/
@@ -492,7 +465,7 @@ abstract class DStream[T: ClassManifest] (
* this DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: (RDD[T], Time) => Unit) {
- val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc))
+ val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc))
ssc.registerOutputStream(newStream)
newStream
}
@@ -510,7 +483,7 @@ abstract class DStream[T: ClassManifest] (
* on each RDD of this DStream.
*/
def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- new TransformedDStream(this, ssc.sc.clean(transformFunc))
+ new TransformedDStream(this, context.sparkContext.clean(transformFunc))
}
/**
@@ -527,19 +500,21 @@ abstract class DStream[T: ClassManifest] (
if (first11.size > 10) println("...")
println()
}
- val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc))
+ val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc))
ssc.registerOutputStream(newStream)
}
/**
- * Return a new DStream which is computed based on windowed batches of this DStream.
- * The new DStream generates RDDs with the same interval as this DStream.
+ * Return a new DStream in which each RDD contains all the elements in seen in a
+ * sliding window of time over this DStream. The new DStream generates RDDs with
+ * the same interval as this DStream.
* @param windowDuration width of the window; must be a multiple of this DStream's interval.
*/
def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration)
/**
- * Return a new DStream which is computed based on windowed batches of this DStream.
+ * Return a new DStream in which each RDD contains all the elements in seen in a
+ * sliding window of time over this DStream.
* @param windowDuration width of the window; must be a multiple of this DStream's
* batching interval
* @param slideDuration sliding interval of the window (i.e., the interval after which
@@ -551,27 +526,39 @@ abstract class DStream[T: ClassManifest] (
}
/**
- * Return a new DStream which computed based on tumbling window on this DStream.
- * This is equivalent to window(batchTime, batchTime).
- * @param batchDuration tumbling window duration; must be a multiple of this DStream's
- * batching interval
- */
- def tumble(batchDuration: Duration): DStream[T] = window(batchDuration, batchDuration)
-
- /**
* Return a new DStream in which each RDD has a single element generated by reducing all
- * elements in a window over this DStream. windowDuration and slideDuration are as defined
- * in the window() operation. This is equivalent to
- * window(windowDuration, slideDuration).reduce(reduceFunc)
+ * elements in a sliding window over this DStream.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
*/
def reduceByWindow(
reduceFunc: (T, T) => T,
windowDuration: Duration,
slideDuration: Duration
): DStream[T] = {
- this.window(windowDuration, slideDuration).reduce(reduceFunc)
+ this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc)
}
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a sliding window over this DStream. However, the reduction is done incrementally
+ * using the old window's reduced value :
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByWindow without "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
def reduceByWindow(
reduceFunc: (T, T) => T,
invReduceFunc: (T, T) => T,
@@ -585,14 +572,47 @@ abstract class DStream[T: ClassManifest] (
/**
* Return a new DStream in which each RDD has a single element generated by counting the number
- * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the
- * window() operation. This is equivalent to window(windowDuration, slideDuration).count()
+ * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with
+ * Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
*/
def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = {
this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration)
}
/**
+ * Return a new DStream in which each RDD contains the count of distinct elements in
+ * RDDs in a sliding window over this DStream. Hash partitioning is used to generate
+ * the RDDs with `numPartitions` partitions (Spark's default number of partitions if
+ * `numPartitions` not specified).
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions number of partitions of each RDD in the new DStream.
+ */
+ def countByValueAndWindow(
+ windowDuration: Duration,
+ slideDuration: Duration,
+ numPartitions: Int = ssc.sc.defaultParallelism
+ ): DStream[(T, Long)] = {
+
+ this.map(x => (x, 1L)).reduceByKeyAndWindow(
+ (x: Long, y: Long) => x + y,
+ (x: Long, y: Long) => x - y,
+ windowDuration,
+ slideDuration,
+ numPartitions,
+ (x: (T, Long)) => x._2 != 0L
+ )
+ }
+
+ /**
* Return a new DStream by unifying data of another DStream with this DStream.
* @param that Another DStream having the same slideDuration as this DStream.
*/
@@ -609,16 +629,21 @@ abstract class DStream[T: ClassManifest] (
* Return all the RDDs between 'fromTime' to 'toTime' (both included)
*/
def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
- val rdds = new ArrayBuffer[RDD[T]]()
- var time = toTime.floor(slideDuration)
- while (time >= zeroTime && time >= fromTime) {
- getOrCompute(time) match {
- case Some(rdd) => rdds += rdd
- case None => //throw new Exception("Could not get RDD for time " + time)
- }
- time -= slideDuration
+ if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) {
+ logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")")
+ }
+ if (!(toTime - zeroTime).isMultipleOf(slideDuration)) {
+ logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")")
}
- rdds.toSeq
+ val alignedToTime = toTime.floor(slideDuration)
+ val alignedFromTime = fromTime.floor(slideDuration)
+
+ logInfo("Slicing from " + fromTime + " to " + toTime +
+ " (aligned to " + alignedFromTime + " and " + alignedToTime + ")")
+
+ alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => {
+ if (time >= zeroTime) getOrCompute(time) else None
+ })
}
/**
@@ -651,7 +676,3 @@ abstract class DStream[T: ClassManifest] (
ssc.registerOutputStream(this)
}
}
-
-private[streaming]
-case class DStreamCheckpointData(rdds: HashMap[Time, Any])
-
diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala
new file mode 100644
index 0000000000..6b0fade7c6
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala
@@ -0,0 +1,93 @@
+package spark.streaming
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.conf.Configuration
+import collection.mutable.HashMap
+import spark.Logging
+
+
+
+private[streaming]
+class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T])
+ extends Serializable with Logging {
+ protected val data = new HashMap[Time, AnyRef]()
+
+ @transient private var fileSystem : FileSystem = null
+ @transient private var lastCheckpointFiles: HashMap[Time, String] = null
+
+ protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]]
+
+ /**
+ * Updates the checkpoint data of the DStream. This gets called every time
+ * the graph checkpoint is initiated. Default implementation records the
+ * checkpoint files to which the generate RDDs of the DStream has been saved.
+ */
+ def update() {
+
+ // Get the checkpointed RDDs from the generated RDDs
+ val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ .map(x => (x._1, x._2.getCheckpointFile.get))
+
+ // Make a copy of the existing checkpoint data (checkpointed RDDs)
+ lastCheckpointFiles = checkpointFiles.clone()
+
+ // If the new checkpoint data has checkpoints then replace existing with the new one
+ if (newCheckpointFiles.size > 0) {
+ checkpointFiles.clear()
+ checkpointFiles ++= newCheckpointFiles
+ }
+
+ // TODO: remove this, this is just for debugging
+ newCheckpointFiles.foreach {
+ case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+ }
+ }
+
+ /**
+ * Cleanup old checkpoint data. This gets called every time the graph
+ * checkpoint is initiated, but after `update` is called. Default
+ * implementation, cleans up old checkpoint files.
+ */
+ def cleanup() {
+ // If there is at least on checkpoint file in the current checkpoint files,
+ // then delete the old checkpoint files.
+ if (checkpointFiles.size > 0 && lastCheckpointFiles != null) {
+ (lastCheckpointFiles -- checkpointFiles.keySet).foreach {
+ case (time, file) => {
+ try {
+ val path = new Path(file)
+ if (fileSystem == null) {
+ fileSystem = path.getFileSystem(new Configuration())
+ }
+ fileSystem.delete(path, true)
+ logInfo("Deleted checkpoint file '" + file + "' for time " + time)
+ } catch {
+ case e: Exception =>
+ logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e)
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Restore the checkpoint data. This gets called once when the DStream graph
+ * (along with its DStreams) are being restored from a graph checkpoint file.
+ * Default implementation restores the RDDs from their checkpoint files.
+ */
+ def restore() {
+ // Create RDDs from the checkpoint data
+ checkpointFiles.foreach {
+ case(time, file) => {
+ logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
+ dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
+ }
+ }
+ }
+
+ override def toString() = {
+ "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]"
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index bc4a40d7bc..adb7f3a24d 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -11,17 +11,20 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
private val inputStreams = new ArrayBuffer[InputDStream[_]]()
private val outputStreams = new ArrayBuffer[DStream[_]]()
- private[streaming] var zeroTime: Time = null
- private[streaming] var batchDuration: Duration = null
- private[streaming] var rememberDuration: Duration = null
- private[streaming] var checkpointInProgress = false
+ var rememberDuration: Duration = null
+ var checkpointInProgress = false
- private[streaming] def start(time: Time) {
+ var zeroTime: Time = null
+ var startTime: Time = null
+ var batchDuration: Duration = null
+
+ def start(time: Time) {
this.synchronized {
if (zeroTime != null) {
throw new Exception("DStream graph computation already started")
}
zeroTime = time
+ startTime = time
outputStreams.foreach(_.initialize(zeroTime))
outputStreams.foreach(_.remember(rememberDuration))
outputStreams.foreach(_.validate)
@@ -29,19 +32,23 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
- private[streaming] def stop() {
+ def restart(time: Time) {
+ this.synchronized { startTime = time }
+ }
+
+ def stop() {
this.synchronized {
inputStreams.par.foreach(_.stop())
}
}
- private[streaming] def setContext(ssc: StreamingContext) {
+ def setContext(ssc: StreamingContext) {
this.synchronized {
outputStreams.foreach(_.setContext(ssc))
}
}
- private[streaming] def setBatchDuration(duration: Duration) {
+ def setBatchDuration(duration: Duration) {
this.synchronized {
if (batchDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
@@ -51,59 +58,68 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
batchDuration = duration
}
- private[streaming] def remember(duration: Duration) {
+ def remember(duration: Duration) {
this.synchronized {
if (rememberDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
". cannot set it again.")
}
+ rememberDuration = duration
}
- rememberDuration = duration
}
- private[streaming] def addInputStream(inputStream: InputDStream[_]) {
+ def addInputStream(inputStream: InputDStream[_]) {
this.synchronized {
inputStream.setGraph(this)
inputStreams += inputStream
}
}
- private[streaming] def addOutputStream(outputStream: DStream[_]) {
+ def addOutputStream(outputStream: DStream[_]) {
this.synchronized {
outputStream.setGraph(this)
outputStreams += outputStream
}
}
- private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray }
+ def getInputStreams() = this.synchronized { inputStreams.toArray }
- private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray }
+ def getOutputStreams() = this.synchronized { outputStreams.toArray }
- private[streaming] def generateRDDs(time: Time): Seq[Job] = {
+ def generateJobs(time: Time): Seq[Job] = {
this.synchronized {
- outputStreams.flatMap(outputStream => outputStream.generateJob(time))
+ logInfo("Generating jobs for time " + time)
+ val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time))
+ logInfo("Generated " + jobs.length + " jobs for time " + time)
+ jobs
}
}
- private[streaming] def forgetOldRDDs(time: Time) {
+ def clearOldMetadata(time: Time) {
this.synchronized {
- outputStreams.foreach(_.forgetOldRDDs(time))
+ logInfo("Clearing old metadata for time " + time)
+ outputStreams.foreach(_.clearOldMetadata(time))
+ logInfo("Cleared old metadata for time " + time)
}
}
- private[streaming] def updateCheckpointData(time: Time) {
+ def updateCheckpointData(time: Time) {
this.synchronized {
+ logInfo("Updating checkpoint data for time " + time)
outputStreams.foreach(_.updateCheckpointData(time))
+ logInfo("Updated checkpoint data for time " + time)
}
}
- private[streaming] def restoreCheckpointData() {
+ def restoreCheckpointData() {
this.synchronized {
+ logInfo("Restoring checkpoint data")
outputStreams.foreach(_.restoreCheckpointData())
+ logInfo("Restored checkpoint data")
}
}
- private[streaming] def validate() {
+ def validate() {
this.synchronized {
assert(batchDuration != null, "Batch duration has not been set")
//assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low")
diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala
index e4dc579a17..ee26206e24 100644
--- a/streaming/src/main/scala/spark/streaming/Duration.scala
+++ b/streaming/src/main/scala/spark/streaming/Duration.scala
@@ -16,7 +16,7 @@ case class Duration (private val millis: Long) {
def * (times: Int): Duration = new Duration(millis * times)
- def / (that: Duration): Long = millis / that.millis
+ def / (that: Duration): Double = millis.toDouble / that.millis.toDouble
def isMultipleOf(that: Duration): Boolean =
(this.millis % that.millis == 0)
diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala
index dc21dfb722..6a8b81760e 100644
--- a/streaming/src/main/scala/spark/streaming/Interval.scala
+++ b/streaming/src/main/scala/spark/streaming/Interval.scala
@@ -30,6 +30,7 @@ class Interval(val beginTime: Time, val endTime: Time) {
override def toString = "[" + beginTime + ", " + endTime + "]"
}
+private[streaming]
object Interval {
def currentInterval(duration: Duration): Interval = {
val time = new Time(System.currentTimeMillis)
diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala
index 3b910538e0..7696c4a592 100644
--- a/streaming/src/main/scala/spark/streaming/JobManager.scala
+++ b/streaming/src/main/scala/spark/streaming/JobManager.scala
@@ -3,6 +3,8 @@ package spark.streaming
import spark.Logging
import spark.SparkEnv
import java.util.concurrent.Executors
+import collection.mutable.HashMap
+import collection.mutable.ArrayBuffer
private[streaming]
@@ -13,21 +15,57 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
SparkEnv.set(ssc.env)
try {
val timeTaken = job.run()
- logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format(
- (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, timeTaken / 1000.0))
+ logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format(
+ (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0))
} catch {
case e: Exception =>
logError("Running " + job + " failed", e)
}
+ clearJob(job)
}
}
initLogging()
val jobExecutor = Executors.newFixedThreadPool(numThreads)
-
+ val jobs = new HashMap[Time, ArrayBuffer[Job]]
+
def runJob(job: Job) {
+ jobs.synchronized {
+ jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job
+ }
jobExecutor.execute(new JobHandler(ssc, job))
logInfo("Added " + job + " to queue")
}
+
+ def stop() {
+ jobExecutor.shutdown()
+ }
+
+ private def clearJob(job: Job) {
+ var timeCleared = false
+ val time = job.time
+ jobs.synchronized {
+ val jobsOfTime = jobs.get(time)
+ if (jobsOfTime.isDefined) {
+ jobsOfTime.get -= job
+ if (jobsOfTime.get.isEmpty) {
+ jobs -= time
+ timeCleared = true
+ }
+ } else {
+ throw new Exception("Job finished for time " + job.time +
+ " but time does not exist in jobs")
+ }
+ }
+ if (timeCleared) {
+ ssc.scheduler.clearOldMetadata(time)
+ }
+ }
+
+ def getPendingTimes(): Array[Time] = {
+ jobs.synchronized {
+ jobs.keySet.toArray
+ }
+ }
}
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
index e4152f3a61..b159d26c02 100644
--- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -4,6 +4,7 @@ import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
import spark.Logging
import spark.SparkEnv
+import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.Queue
@@ -23,7 +24,7 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) ext
*/
private[streaming]
class NetworkInputTracker(
- @transient ssc: StreamingContext,
+ @transient ssc: StreamingContext,
@transient networkInputStreams: Array[NetworkInputDStream[_]])
extends Logging {
@@ -65,12 +66,12 @@ class NetworkInputTracker(
def receive = {
case RegisterReceiver(streamId, receiverActor) => {
if (!networkInputStreamMap.contains(streamId)) {
- throw new Exception("Register received for unexpected id " + streamId)
+ throw new Exception("Register received for unexpected id " + streamId)
}
receiverInfo += ((streamId, receiverActor))
logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
sender ! true
- }
+ }
case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
@@ -85,7 +86,7 @@ class NetworkInputTracker(
}
case DeregisterReceiver(streamId, msg) => {
receiverInfo -= streamId
- logInfo("De-registered receiver for network stream " + streamId
+ logError("De-registered receiver for network stream " + streamId
+ " with message " + msg)
//TODO: Do something about the corresponding NetworkInputDStream
}
@@ -95,8 +96,8 @@ class NetworkInputTracker(
/** This thread class runs all the receivers on the cluster. */
class ReceiverExecutor extends Thread {
val env = ssc.env
-
- override def run() {
+
+ override def run() {
try {
SparkEnv.set(env)
startReceivers()
@@ -113,7 +114,7 @@ class NetworkInputTracker(
*/
def startReceivers() {
val receivers = networkInputStreams.map(nis => {
- val rcvr = nis.createReceiver()
+ val rcvr = nis.getReceiver()
rcvr.setStreamId(nis.id)
rcvr
})
@@ -138,10 +139,14 @@ class NetworkInputTracker(
}
iterator.next().start()
}
+ // Run the dummy Spark job to ensure that all slaves have registered.
+ // This avoids all the receivers to be scheduled on the same node.
+ ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect()
+
// Distribute the receivers and start them
- ssc.sc.runJob(tempRDD, startReceiver)
+ ssc.sparkContext.runJob(tempRDD, startReceiver)
}
-
+
/** Stops the receivers. */
def stopReceivers() {
// Signal the receivers to stop
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
index fbcf061126..3ec922957d 100644
--- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
@@ -18,15 +18,15 @@ import org.apache.hadoop.conf.Configuration
class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)])
extends Serializable {
-
- def ssc = self.ssc
+
+ private[streaming] def ssc = self.ssc
private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
new HashPartitioner(numPartitions)
}
/**
- * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
*/
def groupByKey(): DStream[(K, Seq[V])] = {
@@ -34,7 +34,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
* generate the RDDs with `numPartitions` partitions.
*/
def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = {
@@ -42,7 +42,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]]
+ * Return a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]]
* is used to control the partitioning of each RDD.
*/
def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = {
@@ -54,7 +54,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the associative reduce function. Hash partitioning is used to generate the RDDs
* with Spark's default number of partitions.
*/
@@ -63,7 +63,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the supplied reduce function. Hash partitioning is used to generate the RDDs
* with `numPartitions` partitions.
*/
@@ -72,7 +72,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the supplied reduce function. [[spark.Partitioner]] is used to control the
* partitioning of each RDD.
*/
@@ -82,7 +82,7 @@ extends Serializable {
}
/**
- * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
+ * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the
* combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more
* information.
*/
@@ -95,15 +95,7 @@ extends Serializable {
}
/**
- * Create a new DStream by counting the number of values of each key in each RDD. Hash
- * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions.
- */
- def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = {
- self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions)
- }
-
- /**
- * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to
* `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
* with the same interval as this DStream. Hash partitioning is used to generate the RDDs with
* Spark's default number of partitions.
@@ -115,7 +107,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `groupByKey` over a sliding window. Similar to
+ * Return a new DStream by applying `groupByKey` over a sliding window. Similar to
* `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -129,7 +121,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
* Similar to `DStream.groupByKey()`, but applies it over a sliding window.
* Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -137,7 +129,8 @@ extends Serializable {
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
+ * @param numPartitions number of partitions of each RDD in the new DStream; if not specified
+ * then Spark's default number of partitions will be used
*/
def groupByKeyAndWindow(
windowDuration: Duration,
@@ -155,7 +148,7 @@ extends Serializable {
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream.
*/
def groupByKeyAndWindow(
windowDuration: Duration,
@@ -166,7 +159,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream.
+ * Return a new DStream by applying `reduceByKey` over a sliding window on `this` DStream.
* Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream
* generates RDDs with the same interval as this DStream. Hash partitioning is used to generate
* the RDDs with Spark's default number of partitions.
@@ -182,7 +175,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to
* `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
* @param reduceFunc associative reduce function
@@ -201,7 +194,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to
* `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with `numPartitions` partitions.
* @param reduceFunc associative reduce function
@@ -210,10 +203,10 @@ extends Serializable {
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
+ * @param numPartitions number of partitions of each RDD in the new DStream.
*/
def reduceByKeyAndWindow(
- reduceFunc: (V, V) => V,
+ reduceFunc: (V, V) => V,
windowDuration: Duration,
slideDuration: Duration,
numPartitions: Int
@@ -222,7 +215,7 @@ extends Serializable {
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to
* `DStream.reduceByKey()`, but applies it over a sliding window.
* @param reduceFunc associative reduce function
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -230,7 +223,8 @@ extends Serializable {
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param partitioner partitioner for controlling the partitioning of each RDD
+ * in the new DStream.
*/
def reduceByKeyAndWindow(
reduceFunc: (V, V) => V,
@@ -245,118 +239,78 @@ extends Serializable {
}
/**
- * Create a new DStream by reducing over a using incremental computation.
- * The reduced value of over a new window is calculated using the old window's reduce value :
+ * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+ * The reduced value of over a new window is calculated using the old window's reduced value :
* 1. reduce the new values that entered the window (e.g., adding new counts)
+ *
* 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
- * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function.
+ *
+ * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function.
* However, it is applicable to only "invertible reduce functions".
* Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
* @param reduceFunc associative reduce function
- * @param invReduceFunc inverse function
+ * @param invReduceFunc inverse reduce function
* @param windowDuration width of the window; must be a multiple of this DStream's
* batching interval
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
+ * @param filterFunc Optional function to filter expired key-value pairs;
+ * only pairs that satisfy the function are retained
*/
def reduceByKeyAndWindow(
reduceFunc: (V, V) => V,
invReduceFunc: (V, V) => V,
windowDuration: Duration,
- slideDuration: Duration
+ slideDuration: Duration = self.slideDuration,
+ numPartitions: Int = ssc.sc.defaultParallelism,
+ filterFunc: ((K, V)) => Boolean = null
): DStream[(K, V)] = {
reduceByKeyAndWindow(
- reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner())
- }
-
- /**
- * Create a new DStream by reducing over a using incremental computation.
- * The reduced value of over a new window is calculated using the old window's reduce value :
- * 1. reduce the new values that entered the window (e.g., adding new counts)
- * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
- * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function.
- * However, it is applicable to only "invertible reduce functions".
- * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
- * @param reduceFunc associative reduce function
- * @param invReduceFunc inverse function
- * @param windowDuration width of the window; must be a multiple of this DStream's
- * batching interval
- * @param slideDuration sliding interval of the window (i.e., the interval after which
- * the new DStream will generate RDDs); must be a multiple of this
- * DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
- */
- def reduceByKeyAndWindow(
- reduceFunc: (V, V) => V,
- invReduceFunc: (V, V) => V,
- windowDuration: Duration,
- slideDuration: Duration,
- numPartitions: Int
- ): DStream[(K, V)] = {
-
- reduceByKeyAndWindow(
- reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions))
+ reduceFunc, invReduceFunc, windowDuration,
+ slideDuration, defaultPartitioner(numPartitions), filterFunc
+ )
}
/**
- * Create a new DStream by reducing over a using incremental computation.
- * The reduced value of over a new window is calculated using the old window's reduce value :
+ * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+ * The reduced value of over a new window is calculated using the old window's reduced value :
* 1. reduce the new values that entered the window (e.g., adding new counts)
* 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
- * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function.
+ * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function.
* However, it is applicable to only "invertible reduce functions".
- * @param reduceFunc associative reduce function
- * @param invReduceFunc inverse function
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse reduce function
* @param windowDuration width of the window; must be a multiple of this DStream's
* batching interval
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param filterFunc Optional function to filter expired key-value pairs;
+ * only pairs that satisfy the function are retained
*/
def reduceByKeyAndWindow(
reduceFunc: (V, V) => V,
invReduceFunc: (V, V) => V,
windowDuration: Duration,
slideDuration: Duration,
- partitioner: Partitioner
+ partitioner: Partitioner,
+ filterFunc: ((K, V)) => Boolean
): DStream[(K, V)] = {
val cleanedReduceFunc = ssc.sc.clean(reduceFunc)
val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc)
+ val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None
new ReducedWindowedDStream[K, V](
- self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner)
- }
-
- /**
- * Create a new DStream by counting the number of values for each key over a window.
- * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
- * @param windowDuration width of the window; must be a multiple of this DStream's
- * batching interval
- * @param slideDuration sliding interval of the window (i.e., the interval after which
- * the new DStream will generate RDDs); must be a multiple of this
- * DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
- */
- def countByKeyAndWindow(
- windowDuration: Duration,
- slideDuration: Duration,
- numPartitions: Int = self.ssc.sc.defaultParallelism
- ): DStream[(K, Long)] = {
-
- self.map(x => (x._1, 1L)).reduceByKeyAndWindow(
- (x: Long, y: Long) => x + y,
- (x: Long, y: Long) => x - y,
- windowDuration,
- slideDuration,
- numPartitions
+ self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc,
+ windowDuration, slideDuration, partitioner
)
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -370,7 +324,7 @@ extends Serializable {
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -405,7 +359,7 @@ extends Serializable {
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* [[spark.Paxrtitioner]] is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then
@@ -447,7 +401,7 @@ extends Serializable {
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+ * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this`
* or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
* key in both RDDs. Partitioner is used to partition each generated RDD.
*/
@@ -457,7 +411,7 @@ extends Serializable {
): DStream[(K, (Seq[V], Seq[W]))] = {
val cgd = new CoGroupedDStream[K](
- Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]),
+ Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]),
partitioner
)
val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index c04ed37de8..1c4b22a898 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -9,11 +9,8 @@ class Scheduler(ssc: StreamingContext) extends Logging {
initLogging()
- val graph = ssc.graph
-
val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
val jobManager = new JobManager(ssc, concurrentJobs)
-
val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(ssc.checkpointDir)
} else {
@@ -23,54 +20,93 @@ class Scheduler(ssc: StreamingContext) extends Logging {
val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => generateRDDs(new Time(longTime)))
+ longTime => generateJobs(new Time(longTime)))
+ val graph = ssc.graph
+ var latestTime: Time = null
- def start() {
- // If context was started from checkpoint, then restart timer such that
- // this timer's triggers occur at the same time as the original timer.
- // Otherwise just start the timer from scratch, and initialize graph based
- // on this first trigger time of the timer.
+ def start() = synchronized {
if (ssc.isCheckpointPresent) {
- // If manual clock is being used for testing, then
- // either set the manual clock to the last checkpointed time,
- // or if the property is defined set it to that time
- if (clock.isInstanceOf[ManualClock]) {
- val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds
- val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong
- clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime)
- }
- timer.restart(graph.zeroTime.milliseconds)
- logInfo("Scheduler's timer restarted")
+ restart()
} else {
- val firstTime = new Time(timer.start())
- graph.start(firstTime - ssc.graph.batchDuration)
- logInfo("Scheduler's timer started")
+ startFirstTime()
}
logInfo("Scheduler started")
}
- def stop() {
+ def stop() = synchronized {
timer.stop()
- graph.stop()
+ jobManager.stop()
+ if (checkpointWriter != null) checkpointWriter.stop()
+ ssc.graph.stop()
logInfo("Scheduler stopped")
}
-
- private def generateRDDs(time: Time) {
+
+ private def startFirstTime() {
+ val startTime = new Time(timer.getStartTime())
+ graph.start(startTime - graph.batchDuration)
+ timer.start(startTime.milliseconds)
+ logInfo("Scheduler's timer started at " + startTime)
+ }
+
+ private def restart() {
+
+ // If manual clock is being used for testing, then
+ // either set the manual clock to the last checkpointed time,
+ // or if the property is defined set it to that time
+ if (clock.isInstanceOf[ManualClock]) {
+ val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
+ val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong
+ clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime)
+ }
+
+ val batchDuration = ssc.graph.batchDuration
+
+ // Batches when the master was down, that is,
+ // between the checkpoint and current restart time
+ val checkpointTime = ssc.initialCheckpoint.checkpointTime
+ val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds))
+ val downTimes = checkpointTime.until(restartTime, batchDuration)
+ logInfo("Batches during down time: " + downTimes.mkString(", "))
+
+ // Batches that were unprocessed before failure
+ val pendingTimes = ssc.initialCheckpoint.pendingTimes
+ logInfo("Batches pending processing: " + pendingTimes.mkString(", "))
+ // Reschedule jobs for these times
+ val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
+ logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
+ timesToReschedule.foreach(time =>
+ graph.generateJobs(time).foreach(jobManager.runJob)
+ )
+
+ // Restart the timer
+ timer.start(restartTime.milliseconds)
+ logInfo("Scheduler's timer restarted at " + restartTime)
+ }
+
+ /** Generate jobs and perform checkpoint for the given `time`. */
+ def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
logInfo("\n-----------------------------------------------------\n")
- graph.generateRDDs(time).foreach(jobManager.runJob)
- graph.forgetOldRDDs(time)
+ graph.generateJobs(time).foreach(jobManager.runJob)
+ latestTime = time
+ doCheckpoint(time)
+ }
+
+ /**
+ * Clear old metadata assuming jobs of `time` have finished processing.
+ * And also perform checkpoint.
+ */
+ def clearOldMetadata(time: Time) {
+ ssc.graph.clearOldMetadata(time)
doCheckpoint(time)
- logInfo("Generated RDDs for time " + time)
}
- private def doCheckpoint(time: Time) {
+ /** Perform checkpoint for the give `time`. */
+ def doCheckpoint(time: Time) = synchronized {
if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
- val startTime = System.currentTimeMillis()
+ logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))
- val stopTime = System.currentTimeMillis()
- logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms")
}
}
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 37ba524b48..25c67b279b 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -1,10 +1,19 @@
package spark.streaming
+import akka.actor.Props
+import akka.actor.SupervisorStrategy
+import akka.zeromq.Subscribe
+
import spark.streaming.dstream._
import spark.{RDD, Logging, SparkEnv, SparkContext}
+import spark.streaming.receivers.ActorReceiver
+import spark.streaming.receivers.ReceiverSupervisorStrategy
+import spark.streaming.receivers.ZeroMQReceiver
import spark.storage.StorageLevel
import spark.util.MetadataCleaner
+import spark.streaming.receivers.ActorReceiver
+
import scala.collection.mutable.Queue
@@ -17,6 +26,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import java.util.UUID
+import twitter4j.Status
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -30,23 +40,23 @@ class StreamingContext private (
) extends Logging {
/**
- * Creates a StreamingContext using an existing SparkContext.
+ * Create a StreamingContext using an existing SparkContext.
* @param sparkContext Existing SparkContext
* @param batchDuration The time interval at which streaming data will be divided into batches
*/
def this(sparkContext: SparkContext, batchDuration: Duration) = this(sparkContext, null, batchDuration)
/**
- * Creates a StreamingContext by providing the details necessary for creating a new SparkContext.
+ * Create a StreamingContext by providing the details necessary for creating a new SparkContext.
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param frameworkName A name for your job, to display on the cluster web UI
+ * @param appName A name for your job, to display on the cluster web UI
* @param batchDuration The time interval at which streaming data will be divided into batches
*/
- def this(master: String, frameworkName: String, batchDuration: Duration) =
- this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration)
+ def this(master: String, appName: String, batchDuration: Duration) =
+ this(StreamingContext.createNewSparkContext(master, appName), null, batchDuration)
/**
- * Re-creates a StreamingContext from a checkpoint file.
+ * Re-create a StreamingContext from a checkpoint file.
* @param path Path either to the directory that was specified as the checkpoint directory, or
* to the checkpoint file 'graph' or 'graph.bk'.
*/
@@ -61,7 +71,7 @@ class StreamingContext private (
protected[streaming] val isCheckpointPresent = (cp_ != null)
- val sc: SparkContext = {
+ protected[streaming] val sc: SparkContext = {
if (isCheckpointPresent) {
new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars)
} else {
@@ -101,7 +111,12 @@ class StreamingContext private (
protected[streaming] var scheduler: Scheduler = null
/**
- * Sets each DStreams in this context to remember RDDs it generated in the last given duration.
+ * Return the associated Spark context
+ */
+ def sparkContext = sc
+
+ /**
+ * Set each DStreams in this context to remember RDDs it generated in the last given duration.
* DStreams remember RDDs only for a limited duration of time and releases them for garbage
* collection. This method allows the developer to specify how to long to remember the RDDs (
* if the developer wishes to query old data outside the DStream computation).
@@ -112,71 +127,119 @@ class StreamingContext private (
}
/**
- * Sets the context to periodically checkpoint the DStream operations for master
- * fault-tolerance. By default, the graph will be checkpointed every batch interval.
+ * Set the context to periodically checkpoint the DStream operations for master
+ * fault-tolerance. The graph will be checkpointed every batch interval.
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
- * @param interval checkpoint interval
*/
- def checkpoint(directory: String, interval: Duration = null) {
+ def checkpoint(directory: String) {
if (directory != null) {
sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory))
checkpointDir = directory
- checkpointDuration = interval
} else {
checkpointDir = null
- checkpointDuration = null
}
}
- protected[streaming] def getInitialCheckpoint(): Checkpoint = {
+ protected[streaming] def initialCheckpoint: Checkpoint = {
if (isCheckpointPresent) cp_ else null
}
protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
/**
+ * Create an input stream with any arbitrary user implemented network receiver.
+ * @param receiver Custom implementation of NetworkReceiver
+ */
+ def networkStream[T: ClassManifest](
+ receiver: NetworkReceiver[T]): DStream[T] = {
+ val inputStream = new PluggableInputDStream[T](this,
+ receiver)
+ graph.addInputStream(inputStream)
+ inputStream
+ }
+
+ /**
+ * Create an input stream with any arbitrary user implemented actor receiver.
+ * @param props Props object defining creation of the actor
+ * @param name Name of the actor
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ *
+ * @note An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e parametrized type of data received and actorStream
+ * should be same.
+ */
+ def actorStream[T: ClassManifest](
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
+ supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = {
+ networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
+ }
+
+ /**
+ * Create an input stream that receives messages pushed by a zeromq publisher.
+ * @param publisherUrl Url of remote zeromq publisher
+ * @param subscribe topic to subscribe to
+ * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
+ * of byte thus it needs the converter(which might be deserializer of bytes)
+ * to translate from sequence of sequence of bytes, where sequence refer to a frame
+ * and sub sequence refer to its payload.
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ */
+ def zeroMQStream[T: ClassManifest](
+ publisherUrl:String,
+ subscribe: Subscribe,
+ bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T],
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
+ supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
+ ): DStream[T] = {
+ actorStream(Props(new ZeroMQReceiver(publisherUrl,subscribe,bytesToObjects)),
+ "ZeroMQReceiver", storageLevel, supervisorStrategy)
+ }
+
+ /**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param hostname Zookeper hostname.
- * @param port Zookeper port.
+ * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
* By default the value is pulled from zookeper.
- * @param storageLevel RDD storage level. Defaults to memory-only.
+ * @param storageLevel Storage level to use for storing the received objects
+ * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
def kafkaStream[T: ClassManifest](
- hostname: String,
- port: Int,
+ zkQuorum: String,
groupId: String,
topics: Map[String, Int],
initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
): DStream[T] = {
- val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel)
+ val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel)
registerInputStream(inputStream)
inputStream
}
/**
- * Create a input stream from network source hostname:port. Data is received using
- * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited
+ * Create a input stream from TCP source hostname:port. Data is received using
+ * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited
* lines.
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
- def networkTextStream(
+ def socketTextStream(
hostname: String,
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): DStream[String] = {
- networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
+ socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
}
/**
- * Create a input stream from network source hostname:port. Data is received using
+ * Create a input stream from TCP source hostname:port. Data is received using
* a TCP socket and the receive bytes it interepreted as object using the given
* converter.
* @param hostname Hostname to connect to for receiving data
@@ -185,7 +248,7 @@ class StreamingContext private (
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects received (after converting bytes to objects)
*/
- def networkStream[T: ClassManifest](
+ def socketStream[T: ClassManifest](
hostname: String,
port: Int,
converter: (InputStream) => Iterator[T],
@@ -197,7 +260,7 @@ class StreamingContext private (
}
/**
- * Creates a input stream from a Flume source.
+ * Create a input stream from a Flume source.
* @param hostname Hostname of the slave machine to which the flume data will be sent
* @param port Port of the slave machine to which the flume data will be sent
* @param storageLevel Storage level to use for storing the received objects
@@ -222,7 +285,7 @@ class StreamingContext private (
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects in the received blocks
*/
- def rawNetworkStream[T: ClassManifest](
+ def rawSocketStream[T: ClassManifest](
hostname: String,
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
@@ -233,7 +296,7 @@ class StreamingContext private (
}
/**
- * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * Create a input stream that monitors a Hadoop-compatible filesystem
* for new files and reads them using the given key-value types and input format.
* File names starting with . are ignored.
* @param directory HDFS directory to monitor for new file
@@ -252,7 +315,7 @@ class StreamingContext private (
}
/**
- * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * Create a input stream that monitors a Hadoop-compatible filesystem
* for new files and reads them using the given key-value types and input format.
* @param directory HDFS directory to monitor for new file
* @param filter Function to filter paths to process
@@ -271,9 +334,8 @@ class StreamingContext private (
inputStream
}
-
/**
- * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * Create a input stream that monitors a Hadoop-compatible filesystem
* for new files and reads them as text files (using key as LongWritable, value
* as Text and input format as TextInputFormat). File names starting with . are ignored.
* @param directory HDFS directory to monitor for new file
@@ -283,17 +345,49 @@ class StreamingContext private (
}
/**
- * Creates a input stream from an queue of RDDs. In each batch,
+ * Create a input stream that returns tweets received from Twitter.
+ * @param username Twitter username
+ * @param password Twitter password
+ * @param filters Set of filter strings to get only those tweets that match them
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def twitterStream(
+ username: String,
+ password: String,
+ filters: Seq[String] = Nil,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[Status] = {
+ val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel)
+ registerInputStream(inputStream)
+ inputStream
+ }
+
+ /**
+ * Create an input stream from a queue of RDDs. In each batch,
* it will process either one or all of the RDDs returned by the queue.
* @param queue Queue of RDDs
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
- * @param defaultRDD Default RDD is returned by the DStream when the queue is empty
* @tparam T Type of objects in the RDD
*/
def queueStream[T: ClassManifest](
queue: Queue[RDD[T]],
- oneAtATime: Boolean = true,
- defaultRDD: RDD[T] = null
+ oneAtATime: Boolean = true
+ ): DStream[T] = {
+ queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1))
+ }
+
+ /**
+ * Create an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ * @param queue Queue of RDDs
+ * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
+ * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty
+ * @tparam T Type of objects in the RDD
+ */
+ def queueStream[T: ClassManifest](
+ queue: Queue[RDD[T]],
+ oneAtATime: Boolean,
+ defaultRDD: RDD[T]
): DStream[T] = {
val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD)
registerInputStream(inputStream)
@@ -308,7 +402,7 @@ class StreamingContext private (
}
/**
- * Registers an input stream that will be started (InputDStream.start() called) to get the
+ * Register an input stream that will be started (InputDStream.start() called) to get the
* input data.
*/
def registerInputStream(inputStream: InputDStream[_]) {
@@ -316,7 +410,7 @@ class StreamingContext private (
}
/**
- * Registers an output stream that will be computed every interval
+ * Register an output stream that will be computed every interval
*/
def registerOutputStream(outputStream: DStream[_]) {
graph.addOutputStream(outputStream)
@@ -334,7 +428,7 @@ class StreamingContext private (
}
/**
- * Starts the execution of the streams.
+ * Start the execution of the streams.
*/
def start() {
if (checkpointDir != null && checkpointDuration == null && graph != null) {
@@ -362,7 +456,7 @@ class StreamingContext private (
}
/**
- * Sstops the execution of the streams.
+ * Stop the execution of the streams.
*/
def stop() {
try {
@@ -384,14 +478,14 @@ object StreamingContext {
new PairDStreamFunctions[K, V](stream)
}
- protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = {
+ protected[streaming] def createNewSparkContext(master: String, appName: String): SparkContext = {
// Set the default cleaner delay to an hour if not already set.
// This should be sufficient for even 1 second interval.
if (MetadataCleaner.getDelaySeconds < 0) {
MetadataCleaner.setDelaySeconds(3600)
}
- new SparkContext(master, frameworkName)
+ new SparkContext(master, appName)
}
protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
@@ -408,4 +502,3 @@ object StreamingContext {
new Path(sscCheckpointDir, UUID.randomUUID.toString).toString
}
}
-
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
index 5daeb761dd..f14decf08b 100644
--- a/streaming/src/main/scala/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -37,6 +37,19 @@ case class Time(private val millis: Long) {
def max(that: Time): Time = if (this > that) this else that
+ def until(that: Time, interval: Duration): Seq[Time] = {
+ (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_))
+ }
+
+ def to(that: Time, interval: Duration): Seq[Time] = {
+ (this.milliseconds) to (that.milliseconds) by (interval.milliseconds) map (new Time(_))
+ }
+
+
override def toString: String = (millis.toString + " ms")
+}
+
+object Time {
+ val ordering = Ordering.by((time: Time) => time.millis)
} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
index 2e7466b16c..4d93f0a5f7 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
@@ -4,6 +4,7 @@ import spark.streaming.{Duration, Time, DStream}
import spark.api.java.function.{Function => JFunction}
import spark.api.java.JavaRDD
import spark.storage.StorageLevel
+import spark.RDD
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
@@ -16,9 +17,7 @@ import spark.storage.StorageLevel
*
* This class contains the basic operations available on all DStreams, such as `map`, `filter` and
* `window`. In addition, [[spark.streaming.api.java.JavaPairDStream]] contains operations available
- * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations
- * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through
- * implicit conversions when `spark.streaming.StreamingContext._` is imported.
+ * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`.
*
* DStreams internally is characterized by a few basic properties:
* - A list of other DStreams that the DStream depends on
@@ -26,7 +25,9 @@ import spark.storage.StorageLevel
* - A function that is used to generate an RDD after each time interval
*/
class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T])
- extends JavaDStreamLike[T, JavaDStream[T]] {
+ extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] {
+
+ override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
/** Return a new DStream containing only the elements that satisfy a predicate. */
def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] =
@@ -36,7 +37,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
def cache(): JavaDStream[T] = dstream.cache()
/** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
- def persist(): JavaDStream[T] = dstream.cache()
+ def persist(): JavaDStream[T] = dstream.persist()
/** Persist the RDDs of this DStream with the given storage level */
def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel)
@@ -50,34 +51,27 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
}
/**
- * Return a new DStream which is computed based on windowed batches of this DStream.
- * The new DStream generates RDDs with the same interval as this DStream.
+ * Return a new DStream in which each RDD contains all the elements in seen in a
+ * sliding window of time over this DStream. The new DStream generates RDDs with
+ * the same interval as this DStream.
* @param windowDuration width of the window; must be a multiple of this DStream's interval.
- * @return
*/
def window(windowDuration: Duration): JavaDStream[T] =
dstream.window(windowDuration)
/**
- * Return a new DStream which is computed based on windowed batches of this DStream.
- * @param windowDuration duration (i.e., width) of the window;
- * must be a multiple of this DStream's interval
+ * Return a new DStream in which each RDD contains all the elements in seen in a
+ * sliding window of time over this DStream.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
* @param slideDuration sliding interval of the window (i.e., the interval after which
- * the new DStream will generate RDDs); must be a multiple of this
- * DStream's interval
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
*/
def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] =
dstream.window(windowDuration, slideDuration)
/**
- * Return a new DStream which computed based on tumbling window on this DStream.
- * This is equivalent to window(batchDuration, batchDuration).
- * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval
- */
- def tumble(batchDuration: Duration): JavaDStream[T] =
- dstream.tumble(batchDuration)
-
- /**
* Return a new DStream by unifying data of another DStream with this DStream.
* @param that Another DStream having the same interval (i.e., slideDuration) as this DStream.
*/
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
index b93cb7865a..548809a359 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
@@ -6,17 +6,20 @@ import java.lang.{Long => JLong}
import scala.collection.JavaConversions._
import spark.streaming._
-import spark.api.java.JavaRDD
+import spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD}
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import java.util
import spark.RDD
import JavaDStream._
-trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable {
+trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]
+ extends Serializable {
implicit val classManifest: ClassManifest[T]
def dstream: DStream[T]
+ def wrapRDD(in: RDD[T]): R
+
implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = {
in.map(new JLong(_))
}
@@ -34,6 +37,26 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
def count(): JavaDStream[JLong] = dstream.count()
/**
+ * Return a new DStream in which each RDD contains the counts of each distinct value in
+ * each RDD of this DStream. Hash partitioning is used to generate the RDDs with
+ * Spark's default number of partitions.
+ */
+ def countByValue(): JavaPairDStream[T, JLong] = {
+ JavaPairDStream.scalaToJavaLong(dstream.countByValue())
+ }
+
+ /**
+ * Return a new DStream in which each RDD contains the counts of each distinct value in
+ * each RDD of this DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ * @param numPartitions number of partitions of each RDD in the new DStream.
+ */
+ def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = {
+ JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions))
+ }
+
+
+ /**
* Return a new DStream in which each RDD has a single element generated by counting the number
* of elements in a window over this DStream. windowDuration and slideDuration are as defined in the
* window() operation. This is equivalent to window(windowDuration, slideDuration).count()
@@ -43,6 +66,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
}
/**
+ * Return a new DStream in which each RDD contains the count of distinct elements in
+ * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with
+ * Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration)
+ : JavaPairDStream[T, JLong] = {
+ JavaPairDStream.scalaToJavaLong(
+ dstream.countByValueAndWindow(windowDuration, slideDuration))
+ }
+
+ /**
+ * Return a new DStream in which each RDD contains the count of distinct elements in
+ * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions number of partitions of each RDD in the new DStream.
+ */
+ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int)
+ : JavaPairDStream[T, JLong] = {
+ JavaPairDStream.scalaToJavaLong(
+ dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions))
+ }
+
+ /**
* Return a new DStream in which each RDD is generated by applying glom() to each RDD of
* this DStream. Applying glom() to an RDD coalesces all elements within each partition into
* an array.
@@ -59,8 +115,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
}
/** Return a new DStream by applying a function to all elements of this DStream. */
- def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = {
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = {
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType())
}
@@ -78,10 +134,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
* Return a new DStream by applying a function to all elements of this DStream,
* and then flattening the results
*/
- def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairDStream[K, V] = {
+ def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -100,8 +156,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.
*/
- def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V])
- : JavaPairDStream[K, V] = {
+ def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2])
+ : JavaPairDStream[K2, V2] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType())
}
@@ -114,8 +170,38 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
/**
* Return a new DStream in which each RDD has a single element generated by reducing all
- * elements in a window over this DStream. windowDuration and slideDuration are as defined in the
- * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc)
+ * elements in a sliding window over this DStream.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def reduceByWindow(
+ reduceFunc: (T, T) => T,
+ windowDuration: Duration,
+ slideDuration: Duration
+ ): DStream[T] = {
+ dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
+ }
+
+
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a sliding window over this DStream. However, the reduction is done incrementally
+ * using the old window's reduced value :
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByWindow without "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
*/
def reduceByWindow(
reduceFunc: JFunction2[T, T, T],
@@ -129,35 +215,35 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
/**
* Return all the RDDs between 'fromDuration' to 'toDuration' (both included)
*/
- def slice(fromDuration: Duration, toDuration: Duration): JList[JavaRDD[T]] = {
- new util.ArrayList(dstream.slice(fromDuration, toDuration).map(new JavaRDD(_)).toSeq)
+ def slice(fromTime: Time, toTime: Time): JList[R] = {
+ new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq)
}
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* this DStream will be registered as an output stream and therefore materialized.
*/
- def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) {
- dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd)))
+ def foreach(foreachFunc: JFunction[R, Void]) {
+ dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd)))
}
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
* this DStream will be registered as an output stream and therefore materialized.
*/
- def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) {
- dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time))
+ def foreach(foreachFunc: JFunction2[R, Time, Void]) {
+ dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time))
}
/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of this DStream.
*/
- def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = {
+ def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
def scalaTransform (in: RDD[T]): RDD[U] =
- transformFunc.call(new JavaRDD[T](in)).rdd
+ transformFunc.call(wrapRDD(in)).rdd
dstream.transform(scalaTransform(_))
}
@@ -165,11 +251,41 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of this DStream.
*/
- def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = {
+ def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
def scalaTransform (in: RDD[T], time: Time): RDD[U] =
- transformFunc.call(new JavaRDD[T](in), time).rdd
+ transformFunc.call(wrapRDD(in), time).rdd
+ dstream.transform(scalaTransform(_, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]):
+ JavaPairDStream[K2, V2] = {
+ implicit val cmk: ClassManifest[K2] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]]
+ implicit val cmv: ClassManifest[V2] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]]
+ def scalaTransform (in: RDD[T]): RDD[(K2, V2)] =
+ transformFunc.call(wrapRDD(in)).rdd
+ dstream.transform(scalaTransform(_))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]):
+ JavaPairDStream[K2, V2] = {
+ implicit val cmk: ClassManifest[K2] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]]
+ implicit val cmv: ClassManifest[V2] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]]
+ def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] =
+ transformFunc.call(wrapRDD(in), time).rdd
dstream.transform(scalaTransform(_, _))
}
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
index ef10c091ca..30240cad98 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
@@ -8,34 +8,37 @@ import scala.collection.JavaConversions._
import spark.streaming._
import spark.streaming.StreamingContext._
import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
-import spark.Partitioner
+import spark.{RDD, Partitioner}
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.conf.Configuration
-import spark.api.java.JavaPairRDD
+import spark.api.java.{JavaRDD, JavaPairRDD}
import spark.storage.StorageLevel
import com.google.common.base.Optional
+import spark.RDD
class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
implicit val kManifiest: ClassManifest[K],
implicit val vManifest: ClassManifest[V])
- extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] {
+ extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] {
+
+ override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
// =======================================================================
// Methods common to all DStream's
// =======================================================================
- /** Returns a new DStream containing only the elements that satisfy a predicate. */
+ /** Return a new DStream containing only the elements that satisfy a predicate. */
def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] =
dstream.filter((x => f(x).booleanValue()))
- /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+ /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
def cache(): JavaPairDStream[K, V] = dstream.cache()
- /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
- def persist(): JavaPairDStream[K, V] = dstream.cache()
+ /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+ def persist(): JavaPairDStream[K, V] = dstream.persist()
- /** Persists the RDDs of this DStream with the given storage level */
+ /** Persist the RDDs of this DStream with the given storage level */
def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
/** Method that generates a RDD for the given Duration */
@@ -67,15 +70,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.window(windowDuration, slideDuration)
/**
- * Returns a new DStream which computed based on tumbling window on this DStream.
- * This is equivalent to window(batchDuration, batchDuration).
- * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval
- */
- def tumble(batchDuration: Duration): JavaPairDStream[K, V] =
- dstream.tumble(batchDuration)
-
- /**
- * Returns a new DStream by unifying data of another DStream with this DStream.
+ * Return a new DStream by unifying data of another DStream with this DStream.
* @param that Another DStream having the same interval (i.e., slideDuration) as this DStream.
*/
def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] =
@@ -86,21 +81,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
// =======================================================================
/**
- * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
*/
def groupByKey(): JavaPairDStream[K, JList[V]] =
dstream.groupByKey().mapValues(seqAsJavaList _)
/**
- * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
* generate the RDDs with `numPartitions` partitions.
*/
def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] =
dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _)
/**
- * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream.
+ * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream.
* Therefore, the values for each key in `this` DStream's RDDs are grouped into a
* single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]]
* is used to control the partitioning of each RDD.
@@ -109,7 +104,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.groupByKey(partitioner).mapValues(seqAsJavaList _)
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the associative reduce function. Hash partitioning is used to generate the RDDs
* with Spark's default number of partitions.
*/
@@ -117,7 +112,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.reduceByKey(func)
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the supplied reduce function. Hash partitioning is used to generate the RDDs
* with `numPartitions` partitions.
*/
@@ -125,7 +120,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.reduceByKey(func, numPartitions)
/**
- * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are
* merged using the supplied reduce function. [[spark.Partitioner]] is used to control the
* partitioning of each RDD.
*/
@@ -149,24 +144,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by counting the number of values of each key in each RDD. Hash
- * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions.
- */
- def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = {
- JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions));
- }
-
-
- /**
- * Create a new DStream by counting the number of values of each key in each RDD. Hash
- * partitioning is used to generate the RDDs with the default number of partitions.
- */
- def countByKey(): JavaPairDStream[K, JLong] = {
- JavaPairDStream.scalaToJavaLong(dstream.countByKey());
- }
-
- /**
- * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to
* `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
* with the same interval as this DStream. Hash partitioning is used to generate the RDDs with
* Spark's default number of partitions.
@@ -178,7 +156,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `groupByKey` over a sliding window. Similar to
+ * Return a new DStream by applying `groupByKey` over a sliding window. Similar to
* `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -193,7 +171,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
* Similar to `DStream.groupByKey()`, but applies it over a sliding window.
* Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -210,7 +188,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
* Similar to `DStream.groupByKey()`, but applies it over a sliding window.
* @param windowDuration width of the window; must be a multiple of this DStream's
* batching interval
@@ -243,7 +221,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to
* `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with Spark's default number of partitions.
* @param reduceFunc associative reduce function
@@ -262,7 +240,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to
* `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
* generate the RDDs with `numPartitions` partitions.
* @param reduceFunc associative reduce function
@@ -283,7 +261,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to
+ * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to
* `DStream.reduceByKey()`, but applies it over a sliding window.
* @param reduceFunc associative reduce function
* @param windowDuration width of the window; must be a multiple of this DStream's
@@ -303,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by reducing over a using incremental computation.
+ * Return a new DStream by reducing over a using incremental computation.
* The reduced value of over a new window is calculated using the old window's reduce value :
* 1. reduce the new values that entered the window (e.g., adding new counts)
* 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
@@ -328,7 +306,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new DStream by reducing over a using incremental computation.
+ * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
* The reduced value of over a new window is calculated using the old window's reduce value :
* 1. reduce the new values that entered the window (e.g., adding new counts)
* 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
@@ -342,25 +320,31 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
+ * @param numPartitions number of partitions of each RDD in the new DStream.
+ * @param filterFunc function to filter expired key-value pairs;
+ * only pairs that satisfy the function are retained
+ * set this to null if you do not want to filter
*/
def reduceByKeyAndWindow(
reduceFunc: Function2[V, V, V],
invReduceFunc: Function2[V, V, V],
windowDuration: Duration,
slideDuration: Duration,
- numPartitions: Int
+ numPartitions: Int,
+ filterFunc: JFunction[(K, V), java.lang.Boolean]
): JavaPairDStream[K, V] = {
dstream.reduceByKeyAndWindow(
reduceFunc,
invReduceFunc,
windowDuration,
slideDuration,
- numPartitions)
+ numPartitions,
+ (p: (K, V)) => filterFunc(p).booleanValue()
+ )
}
/**
- * Create a new DStream by reducing over a using incremental computation.
+ * Return a new DStream by applying incremental `reduceByKey` over a sliding window.
* The reduced value of over a new window is calculated using the old window's reduce value :
* 1. reduce the new values that entered the window (e.g., adding new counts)
* 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
@@ -374,49 +358,26 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param filterFunc function to filter expired key-value pairs;
+ * only pairs that satisfy the function are retained
+ * set this to null if you do not want to filter
*/
def reduceByKeyAndWindow(
reduceFunc: Function2[V, V, V],
invReduceFunc: Function2[V, V, V],
windowDuration: Duration,
slideDuration: Duration,
- partitioner: Partitioner
- ): JavaPairDStream[K, V] = {
+ partitioner: Partitioner,
+ filterFunc: JFunction[(K, V), java.lang.Boolean]
+ ): JavaPairDStream[K, V] = {
dstream.reduceByKeyAndWindow(
reduceFunc,
invReduceFunc,
windowDuration,
slideDuration,
- partitioner)
- }
-
- /**
- * Create a new DStream by counting the number of values for each key over a window.
- * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
- * @param windowDuration width of the window; must be a multiple of this DStream's
- * batching interval
- * @param slideDuration sliding interval of the window (i.e., the interval after which
- * the new DStream will generate RDDs); must be a multiple of this
- * DStream's batching interval
- */
- def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration)
- : JavaPairDStream[K, JLong] = {
- JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration))
- }
-
- /**
- * Create a new DStream by counting the number of values for each key over a window.
- * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
- * @param windowDuration width of the window; must be a multiple of this DStream's
- * batching interval
- * @param slideDuration sliding interval of the window (i.e., the interval after which
- * the new DStream will generate RDDs); must be a multiple of this
- * DStream's batching interval
- * @param numPartitions Number of partitions of each RDD in the new DStream.
- */
- def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int)
- : JavaPairDStream[K, Long] = {
- dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions)
+ partitioner,
+ (p: (K, V)) => filterFunc(p).booleanValue()
+ )
}
private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
index e7f446a49b..f3b40b5b88 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -1,16 +1,26 @@
package spark.streaming.api.java
-import scala.collection.JavaConversions._
-import java.lang.{Long => JLong, Integer => JInt}
-
import spark.streaming._
-import dstream._
+import receivers.{ActorReceiver, ReceiverSupervisorStrategy}
+import spark.streaming.dstream._
import spark.storage.StorageLevel
+
import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import spark.api.java.{JavaSparkContext, JavaRDD}
+
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+
+import twitter4j.Status
+
+import akka.actor.Props
+import akka.actor.SupervisorStrategy
+import akka.zeromq.Subscribe
+
+import scala.collection.JavaConversions._
+
+import java.lang.{Long => JLong, Integer => JInt}
import java.io.InputStream
import java.util.{Map => JMap}
-import spark.api.java.{JavaSparkContext, JavaRDD}
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -27,11 +37,11 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Creates a StreamingContext.
* @param master Name of the Spark Master
- * @param frameworkName Name to be used when registering with the scheduler
+ * @param appName Name to be used when registering with the scheduler
* @param batchDuration The time interval at which streaming data will be divided into batches
*/
- def this(master: String, frameworkName: String, batchDuration: Duration) =
- this(new StreamingContext(master, frameworkName, batchDuration))
+ def this(master: String, appName: String, batchDuration: Duration) =
+ this(new StreamingContext(master, appName, batchDuration))
/**
* Creates a StreamingContext.
@@ -53,27 +63,24 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param hostname Zookeper hostname.
- * @param port Zookeper port.
+ * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
*/
def kafkaStream[T](
- hostname: String,
- port: Int,
+ zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- ssc.kafkaStream[T](hostname, port, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
+ ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
}
/**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param hostname Zookeper hostname.
- * @param port Zookeper port.
+ * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
@@ -81,8 +88,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* By default the value is pulled from zookeper.
*/
def kafkaStream[T](
- hostname: String,
- port: Int,
+ zkQuorum: String,
groupId: String,
topics: JMap[String, JInt],
initialOffsets: JMap[KafkaPartitionKey, JLong])
@@ -90,8 +96,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
ssc.kafkaStream[T](
- hostname,
- port,
+ zkQuorum,
groupId,
Map(topics.mapValues(_.intValue()).toSeq: _*),
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
@@ -99,8 +104,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param hostname Zookeper hostname.
- * @param port Zookeper port.
+ * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
@@ -109,8 +113,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param storageLevel RDD storage level. Defaults to memory-only
*/
def kafkaStream[T](
- hostname: String,
- port: Int,
+ zkQuorum: String,
groupId: String,
topics: JMap[String, JInt],
initialOffsets: JMap[KafkaPartitionKey, JLong],
@@ -119,8 +122,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
ssc.kafkaStream[T](
- hostname,
- port,
+ zkQuorum,
groupId,
Map(topics.mapValues(_.intValue()).toSeq: _*),
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
@@ -136,9 +138,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
- def networkTextStream(hostname: String, port: Int, storageLevel: StorageLevel)
+ def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel)
: JavaDStream[String] = {
- ssc.networkTextStream(hostname, port, storageLevel)
+ ssc.socketTextStream(hostname, port, storageLevel)
}
/**
@@ -148,8 +150,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param hostname Hostname to connect to for receiving data
* @param port Port to connect to for receiving data
*/
- def networkTextStream(hostname: String, port: Int): JavaDStream[String] = {
- ssc.networkTextStream(hostname, port)
+ def socketTextStream(hostname: String, port: Int): JavaDStream[String] = {
+ ssc.socketTextStream(hostname, port)
}
/**
@@ -162,7 +164,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects received (after converting bytes to objects)
*/
- def networkStream[T](
+ def socketStream[T](
hostname: String,
port: Int,
converter: JFunction[InputStream, java.lang.Iterable[T]],
@@ -171,7 +173,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
def fn = (x: InputStream) => converter.apply(x).toIterator
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- ssc.networkStream(hostname, port, fn, storageLevel)
+ ssc.socketStream(hostname, port, fn, storageLevel)
}
/**
@@ -194,13 +196,13 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects in the received blocks
*/
- def rawNetworkStream[T](
+ def rawSocketStream[T](
hostname: String,
port: Int,
storageLevel: StorageLevel): JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port, storageLevel))
+ JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port, storageLevel))
}
/**
@@ -212,10 +214,10 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param port Port to connect to for receiving data
* @tparam T Type of the objects in the received blocks
*/
- def rawNetworkStream[T](hostname: String, port: Int): JavaDStream[T] = {
+ def rawSocketStream[T](hostname: String, port: Int): JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port))
+ JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port))
}
/**
@@ -254,15 +256,182 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param hostname Hostname of the slave machine to which the flume data will be sent
* @param port Port of the slave machine to which the flume data will be sent
*/
- def flumeStream(hostname: String, port: Int):
- JavaDStream[SparkFlumeEvent] = {
+ def flumeStream(hostname: String, port: Int): JavaDStream[SparkFlumeEvent] = {
ssc.flumeStream(hostname, port)
}
/**
+ * Create a input stream that returns tweets received from Twitter.
+ * @param username Twitter username
+ * @param password Twitter password
+ * @param filters Set of filter strings to get only those tweets that match them
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def twitterStream(
+ username: String,
+ password: String,
+ filters: Array[String],
+ storageLevel: StorageLevel
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(username, password, filters, storageLevel)
+ }
+
+ /**
+ * Create a input stream that returns tweets received from Twitter.
+ * @param username Twitter username
+ * @param password Twitter password
+ * @param filters Set of filter strings to get only those tweets that match them
+ */
+ def twitterStream(
+ username: String,
+ password: String,
+ filters: Array[String]
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(username, password, filters)
+ }
+
+ /**
+ * Create a input stream that returns tweets received from Twitter.
+ * @param username Twitter username
+ * @param password Twitter password
+ */
+ def twitterStream(
+ username: String,
+ password: String
+ ): JavaDStream[Status] = {
+ ssc.twitterStream(username, password)
+ }
+
+ /**
+ * Create an input stream with any arbitrary user implemented actor receiver.
+ * @param props Props object defining creation of the actor
+ * @param name Name of the actor
+ * @param storageLevel Storage level to use for storing the received objects
+ *
+ * @note An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e parametrized type of data received and actorStream
+ * should be same.
+ */
+ def actorStream[T](
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel,
+ supervisorStrategy: SupervisorStrategy
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.actorStream[T](props, name, storageLevel, supervisorStrategy)
+ }
+
+ /**
+ * Create an input stream with any arbitrary user implemented actor receiver.
+ * @param props Props object defining creation of the actor
+ * @param name Name of the actor
+ * @param storageLevel Storage level to use for storing the received objects
+ *
+ * @note An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e parametrized type of data received and actorStream
+ * should be same.
+ */
+ def actorStream[T](
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.actorStream[T](props, name, storageLevel)
+ }
+
+ /**
+ * Create an input stream with any arbitrary user implemented actor receiver.
+ * @param props Props object defining creation of the actor
+ * @param name Name of the actor
+ *
+ * @note An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e parametrized type of data received and actorStream
+ * should be same.
+ */
+ def actorStream[T](
+ props: Props,
+ name: String
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.actorStream[T](props, name)
+ }
+
+ /**
+ * Create an input stream that receives messages pushed by a zeromq publisher.
+ * @param publisherUrl Url of remote zeromq publisher
+ * @param subscribe topic to subscribe to
+ * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
+ * of byte thus it needs the converter(which might be deserializer of bytes)
+ * to translate from sequence of sequence of bytes, where sequence refer to a frame
+ * and sub sequence refer to its payload.
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def zeroMQStream[T](
+ publisherUrl:String,
+ subscribe: Subscribe,
+ bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T],
+ storageLevel: StorageLevel,
+ supervisorStrategy: SupervisorStrategy
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.zeroMQStream[T](publisherUrl, subscribe, bytesToObjects, storageLevel, supervisorStrategy)
+ }
+
+ /**
+ * Create an input stream that receives messages pushed by a zeromq publisher.
+ * @param publisherUrl Url of remote zeromq publisher
+ * @param subscribe topic to subscribe to
+ * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
+ * of byte thus it needs the converter(which might be deserializer of bytes)
+ * to translate from sequence of sequence of bytes, where sequence refer to a frame
+ * and sub sequence refer to its payload.
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ */
+ def zeroMQStream[T](
+ publisherUrl:String,
+ subscribe: Subscribe,
+ bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]],
+ storageLevel: StorageLevel
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
+ ssc.zeroMQStream[T](publisherUrl, subscribe, fn, storageLevel)
+ }
+
+ /**
+ * Create an input stream that receives messages pushed by a zeromq publisher.
+ * @param publisherUrl Url of remote zeromq publisher
+ * @param subscribe topic to subscribe to
+ * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
+ * of byte thus it needs the converter(which might be deserializer of bytes)
+ * to translate from sequence of sequence of bytes, where sequence refer to a frame
+ * and sub sequence refer to its payload.
+ */
+ def zeroMQStream[T](
+ publisherUrl:String,
+ subscribe: Subscribe,
+ bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]]
+ ): JavaDStream[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
+ ssc.zeroMQStream[T](publisherUrl, subscribe, fn)
+ }
+
+ /**
* Registers an output stream that will be computed every interval
*/
- def registerOutputStream(outputStream: JavaDStreamLike[_, _]) {
+ def registerOutputStream(outputStream: JavaDStreamLike[_, _, _]) {
ssc.registerOutputStream(outputStream.dstream)
}
@@ -322,12 +491,11 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Sets the context to periodically checkpoint the DStream operations for master
- * fault-tolerance. By default, the graph will be checkpointed every batch interval.
+ * fault-tolerance. The graph will be checkpointed every batch interval.
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
- * @param interval checkpoint interval
*/
- def checkpoint(directory: String, interval: Duration = null) {
- ssc.checkpoint(directory, interval)
+ def checkpoint(directory: String) {
+ ssc.checkpoint(directory)
}
/**
diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
index ddb1bf6b28..4ef4bb7de1 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
@@ -6,7 +6,7 @@ import spark.streaming.{Time, DStream, Duration}
private[streaming]
class CoGroupedDStream[K : ClassManifest](
- parents: Seq[DStream[(_, _)]],
+ parents: Seq[DStream[(K, _)]],
partitioner: Partitioner
) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) {
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
index 1e6ad84b44..41b9bd9461 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
@@ -2,13 +2,14 @@ package spark.streaming.dstream
import spark.RDD
import spark.rdd.UnionRDD
-import spark.streaming.{StreamingContext, Time}
+import spark.streaming.{DStreamCheckpointData, StreamingContext, Time}
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{HashSet, HashMap}
+import java.io.{ObjectInputStream, IOException}
private[streaming]
class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
@@ -18,28 +19,23 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
newFilesOnly: Boolean = true)
extends InputDStream[(K, V)](ssc_) {
- @transient private var path_ : Path = null
- @transient private var fs_ : FileSystem = null
-
- var lastModTime = 0L
- val lastModTimeFiles = new HashSet[String]()
+ protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
- def path(): Path = {
- if (path_ == null) path_ = new Path(directory)
- path_
- }
+ // Latest file mod time seen till any point of time
+ private val lastModTimeFiles = new HashSet[String]()
+ private var lastModTime = 0L
- def fs(): FileSystem = {
- if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
- fs_
- }
+ @transient private var path_ : Path = null
+ @transient private var fs_ : FileSystem = null
+ @transient private[streaming] var files = new HashMap[Time, Array[String]]
override def start() {
if (newFilesOnly) {
- lastModTime = System.currentTimeMillis()
+ lastModTime = graph.zeroTime.milliseconds
} else {
lastModTime = 0
}
+ logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly)
}
override def stop() { }
@@ -49,38 +45,50 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
* a union RDD out of them. Note that this maintains the list of files that were processed
* in the latest modification time in the previous call to this method. This is because the
* modification time returned by the FileStatus API seems to return times only at the
- * granularity of seconds. Hence, new files may have the same modification time as the
- * latest modification time in the previous call to this method and the list of files
- * maintained is used to filter the one that have been processed.
+ * granularity of seconds. And new files may have the same modification time as the
+ * latest modification time in the previous call to this method yet was not reported in
+ * the previous call.
*/
override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+ assert(validTime.milliseconds >= lastModTime, "Trying to get new files for really old time [" + validTime + " < " + lastModTime)
+
// Create the filter for selecting new files
val newFilter = new PathFilter() {
+ // Latest file mod time seen in this round of fetching files and its corresponding files
var latestModTime = 0L
val latestModTimeFiles = new HashSet[String]()
def accept(path: Path): Boolean = {
- if (!filter(path)) {
+ if (!filter(path)) { // Reject file if it does not satisfy filter
+ logDebug("Rejected by filter " + path)
return false
- } else {
+ } else { // Accept file only if
val modTime = fs.getFileStatus(path).getModificationTime()
- if (modTime < lastModTime){
- return false
+ logDebug("Mod time for " + path + " is " + modTime)
+ if (modTime < lastModTime) {
+ logDebug("Mod time less than last mod time")
+ return false // If the file was created before the last time it was called
} else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) {
- return false
+ logDebug("Mod time equal to last mod time, but file considered already")
+ return false // If the file was created exactly as lastModTime but not reported yet
+ } else if (modTime > validTime.milliseconds) {
+ logDebug("Mod time more than valid time")
+ return false // If the file was created after the time this function call requires
}
if (modTime > latestModTime) {
latestModTime = modTime
latestModTimeFiles.clear()
+ logDebug("Latest mod time updated to " + latestModTime)
}
latestModTimeFiles += path.toString
+ logDebug("Accepted " + path)
return true
}
}
}
-
- val newFiles = fs.listStatus(path, newFilter)
- logInfo("New files: " + newFiles.map(_.getPath).mkString(", "))
+ logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime)
+ val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString)
+ logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n"))
if (newFiles.length > 0) {
// Update the modification time and the files processed for that modification time
if (lastModTime != newFilter.latestModTime) {
@@ -88,10 +96,81 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
lastModTimeFiles.clear()
}
lastModTimeFiles ++= newFilter.latestModTimeFiles
+ logDebug("Last mod time updated to " + lastModTime)
+ }
+ files += ((validTime, newFiles))
+ Some(filesToRDD(newFiles))
+ }
+
+ /** Clear the old time-to-files mappings along with old RDDs */
+ protected[streaming] override def clearOldMetadata(time: Time) {
+ super.clearOldMetadata(time)
+ val oldFiles = files.filter(_._1 <= (time - rememberDuration))
+ files --= oldFiles.keys
+ logInfo("Cleared " + oldFiles.size + " old files that were older than " +
+ (time - rememberDuration) + ": " + oldFiles.keys.mkString(", "))
+ logDebug("Cleared files are:\n" +
+ oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n"))
+ }
+
+ /** Generate one RDD from an array of files */
+ protected[streaming] def filesToRDD(files: Seq[String]): RDD[(K, V)] = {
+ new UnionRDD(
+ context.sparkContext,
+ files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file))
+ )
+ }
+
+ private def path: Path = {
+ if (path_ == null) path_ = new Path(directory)
+ path_
+ }
+
+ private def fs: FileSystem = {
+ if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+ fs_
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ logDebug(this.getClass().getSimpleName + ".readObject used")
+ ois.defaultReadObject()
+ generatedRDDs = new HashMap[Time, RDD[(K,V)]] ()
+ files = new HashMap[Time, Array[String]]
+ }
+
+ /**
+ * A custom version of the DStreamCheckpointData that stores names of
+ * Hadoop files as checkpoint data.
+ */
+ private[streaming]
+ class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+
+ def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]]
+
+ override def update() {
+ hadoopFiles.clear()
+ hadoopFiles ++= files
+ }
+
+ override def cleanup() { }
+
+ override def restore() {
+ hadoopFiles.foreach {
+ case (t, f) => {
+ // Restore the metadata in both files and generatedRDDs
+ logInfo("Restoring files for time " + t + " - " +
+ f.mkString("[", ", ", "]") )
+ files += ((t, f))
+ generatedRDDs += ((t, filesToRDD(f)))
+ }
+ }
+ }
+
+ override def toString() = {
+ "[\n" + hadoopFiles.size + " file sets\n" +
+ hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]"
}
- val newRDD = new UnionRDD(ssc.sc, newFiles.map(
- file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)))
- Some(newRDD)
}
}
@@ -100,3 +179,4 @@ object FileInputDStream {
def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".")
}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
index efc7058480..c9644b3a83 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
@@ -25,7 +25,7 @@ class FlumeInputDStream[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkInputDStream[SparkFlumeEvent](ssc_) {
- override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = {
+ override def getReceiver(): NetworkReceiver[SparkFlumeEvent] = {
new FlumeReceiver(host, port, storageLevel)
}
}
@@ -134,4 +134,4 @@ class FlumeReceiver(
}
override def getLocationPreference = Some(host)
-} \ No newline at end of file
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
index 980ca5177e..3c5d43a609 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
@@ -1,10 +1,42 @@
package spark.streaming.dstream
-import spark.streaming.{Duration, StreamingContext, DStream}
+import spark.streaming.{Time, Duration, StreamingContext, DStream}
+/**
+ * This is the abstract base class for all input streams. This class provides to methods
+ * start() and stop() which called by the scheduler to start and stop receiving data/
+ * Input streams that can generated RDDs from new data just by running a service on
+ * the driver node (that is, without running a receiver onworker nodes) can be
+ * implemented by directly subclassing this InputDStream. For example,
+ * FileInputDStream, a subclass of InputDStream, monitors a HDFS directory for
+ * new files and generates RDDs on the new files. For implementing input streams
+ * that requires running a receiver on the worker nodes, use NetworkInputDStream
+ * as the parent class.
+ */
abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
extends DStream[T](ssc_) {
+ var lastValidTime: Time = null
+
+ /**
+ * Checks whether the 'time' is valid wrt slideDuration for generating RDD.
+ * Additionally it also ensures valid times are in strictly increasing order.
+ * This ensures that InputDStream.compute() is called strictly on increasing
+ * times.
+ */
+ override protected def isTimeValid(time: Time): Boolean = {
+ if (!super.isTimeValid(time)) {
+ false // Time not valid
+ } else {
+ // Time is valid, but check it it is more than lastValidTime
+ if (lastValidTime != null && time < lastValidTime) {
+ logWarning("isTimeValid called with " + time + " where as last valid time is " + lastValidTime)
+ }
+ lastValidTime = time
+ true
+ }
+ }
+
override def dependencies = List()
override def slideDuration: Duration = {
@@ -13,7 +45,9 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex
ssc.graph.batchDuration
}
+ /** Method called to start receiving data. Subclasses must implement this method. */
def start()
+ /** Method called to stop receiving data. Subclasses must implement this method. */
def stop()
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
index 2b4740bdf7..dc7139cc27 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -19,21 +19,11 @@ import scala.collection.JavaConversions._
// Key for a specific Kafka Partition: (broker, topic, group, part)
case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
-// NOT USED - Originally intended for fault-tolerance
-// Metadata for a Kafka Stream that it sent to the Master
-private[streaming]
-case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
-// NOT USED - Originally intended for fault-tolerance
-// Checkpoint data specific to a KafkaInputDstream
-private[streaming]
-case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
- savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds)
/**
* Input stream that pulls messages from a Kafka Broker.
*
- * @param host Zookeper hostname.
- * @param port Zookeper port.
+ * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
@@ -44,65 +34,22 @@ case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
private[streaming]
class KafkaInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
- host: String,
- port: Int,
+ zkQuorum: String,
groupId: String,
topics: Map[String, Int],
initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
- // Metadata that keeps track of which messages have already been consumed.
- var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]()
-
- /* NOT USED - Originally intended for fault-tolerance
-
- // In case of a failure, the offets for a particular timestamp will be restored.
- @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null
-
-
- override protected[streaming] def addMetadata(metadata: Any) {
- metadata match {
- case x : KafkaInputDStreamMetadata =>
- savedOffsets(x.timestamp) = x.data
- // TOOD: Remove logging
- logInfo("New saved Offsets: " + savedOffsets)
- case _ => logInfo("Received unknown metadata: " + metadata.toString)
- }
- }
-
- override protected[streaming] def updateCheckpointData(currentTime: Time) {
- super.updateCheckpointData(currentTime)
- if(savedOffsets.size > 0) {
- // Find the offets that were stored before the checkpoint was initiated
- val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last
- val latestOffsets = savedOffsets(key)
- logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString)
- checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets)
- // TODO: This may throw out offsets that are created after the checkpoint,
- // but it's unlikely we'll need them.
- savedOffsets.clear()
- }
- }
-
- override protected[streaming] def restoreCheckpointData() {
- super.restoreCheckpointData()
- logInfo("Restoring KafkaDStream checkpoint data.")
- checkpointData match {
- case x : KafkaDStreamCheckpointData =>
- restoredOffsets = x.savedOffsets
- logInfo("Restored KafkaDStream offsets: " + savedOffsets)
- }
- } */
- def createReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel)
+ def getReceiver(): NetworkReceiver[T] = {
+ new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
-class KafkaReceiver(host: String, port: Int, groupId: String,
+class KafkaReceiver(zkQuorum: String, groupId: String,
topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel) extends NetworkReceiver[Any] {
@@ -111,8 +58,6 @@ class KafkaReceiver(host: String, port: Int, groupId: String,
// Handles pushing data into the BlockManager
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
- // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset
- lazy val offsets = HashMap[KafkaPartitionKey, Long]()
// Connection to Kafka
var consumerConnector : ZookeeperConsumerConnector = null
@@ -127,24 +72,23 @@ class KafkaReceiver(host: String, port: Int, groupId: String,
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- val zooKeeperEndPoint = host + ":" + port
logInfo("Starting Kafka Consumer Stream with group: " + groupId)
logInfo("Initial offsets: " + initialOffsets.toString)
-
+
// Zookeper connection properties
val props = new Properties()
- props.put("zk.connect", zooKeeperEndPoint)
+ props.put("zk.connect", zkQuorum)
props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
props.put("groupid", groupId)
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + zooKeeperEndPoint)
+ logInfo("Connecting to Zookeper: " + zkQuorum)
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
- logInfo("Connected to " + zooKeeperEndPoint)
+ logInfo("Connected to " + zkQuorum)
- // Reset the Kafka offsets in case we are recovering from a failure
- resetOffsets(initialOffsets)
+ // If specified, set the topic offset
+ setOffsets(initialOffsets)
// Create Threads for each Topic/Message Stream we are listening
val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
@@ -157,11 +101,11 @@ class KafkaReceiver(host: String, port: Int, groupId: String,
}
// Overwrites the offets in Zookeper.
- private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) {
+ private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) {
offsets.foreach { case(key, offset) =>
val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
val partitionName = key.brokerId + "-" + key.partId
- updatePersistentPath(consumerConnector.zkClient,
+ updatePersistentPath(consumerConnector.zkClient,
topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
}
}
@@ -172,29 +116,10 @@ class KafkaReceiver(host: String, port: Int, groupId: String,
logInfo("Starting MessageHandler.")
stream.takeWhile { msgAndMetadata =>
blockGenerator += msgAndMetadata.message
-
- // Updating the offet. The key is (broker, topic, group, partition).
- val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
- groupId, msgAndMetadata.topicInfo.partition.partId)
- val offset = msgAndMetadata.topicInfo.getConsumeOffset
- offsets.put(key, offset)
- // logInfo("Handled message: " + (key, offset).toString)
-
// Keep on handling messages
+
true
- }
+ }
}
}
-
- // NOT USED - Originally intended for fault-tolerance
- // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
- // extends BufferingBlockCreator[Any](receiver, storageLevel) {
-
- // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
- // // Creates a new Block with Kafka-specific Metadata
- // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap))
- // }
-
- // }
-
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
index 8c322dd698..7385474963 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -20,7 +20,7 @@ import java.util.concurrent.ArrayBlockingQueue
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
* nodes to receive external data. Specific implementations of NetworkInputDStream must
- * define the createReceiver() function that creates the receiver object of type
+ * define the getReceiver() function that gets the receiver object of type
* [[spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive
* data.
* @param ssc_ Streaming context that will execute this input stream
@@ -34,11 +34,11 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
val id = ssc.getNewNetworkStreamId()
/**
- * Creates the receiver object that will be sent to the worker nodes
+ * Gets the receiver object that will be sent to the worker nodes
* to receive data. This method needs to defined by any specific implementation
* of a NetworkInputDStream.
*/
- def createReceiver(): NetworkReceiver[T]
+ def getReceiver(): NetworkReceiver[T]
// Nothing to start or stop as both taken care of by the NetworkInputTracker.
def start() {}
@@ -46,8 +46,15 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
def stop() {}
override def compute(validTime: Time): Option[RDD[T]] = {
- val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
- Some(new BlockRDD[T](ssc.sc, blockIds))
+ // If this is called for any time before the start time of the context,
+ // then this returns an empty RDD. This may happen when recovering from a
+ // master failure
+ if (validTime >= graph.startTime) {
+ val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ Some(new BlockRDD[T](ssc.sc, blockIds))
+ } else {
+ Some(new BlockRDD[T](ssc.sc, Array[String]()))
+ }
}
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala
new file mode 100644
index 0000000000..3c2a81947b
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala
@@ -0,0 +1,13 @@
+package spark.streaming.dstream
+
+import spark.streaming.StreamingContext
+
+private[streaming]
+class PluggableInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ receiver: NetworkReceiver[T]) extends NetworkInputDStream[T](ssc_) {
+
+ def getReceiver(): NetworkReceiver[T] = {
+ receiver
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
index 024bf3bea4..6b310bc0b6 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
@@ -7,6 +7,7 @@ import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import spark.streaming.{Time, StreamingContext}
+private[streaming]
class QueueInputDStream[T: ClassManifest](
@transient ssc: StreamingContext,
val queue: Queue[RDD[T]],
diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
index 04e6b69b7b..1b2fa56779 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
@@ -25,7 +25,7 @@ class RawInputDStream[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
- def createReceiver(): NetworkReceiver[T] = {
+ def getReceiver(): NetworkReceiver[T] = {
new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[NetworkReceiver[T]]
}
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
index 733d5c4a25..343b6915e7 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -3,7 +3,7 @@ package spark.streaming.dstream
import spark.streaming.StreamingContext._
import spark.RDD
-import spark.rdd.CoGroupedRDD
+import spark.rdd.{CoGroupedRDD, MapPartitionsRDD}
import spark.Partitioner
import spark.SparkContext._
import spark.storage.StorageLevel
@@ -15,7 +15,8 @@ private[streaming]
class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
parent: DStream[(K, V)],
reduceFunc: (V, V) => V,
- invReduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ filterFunc: Option[((K, V)) => Boolean],
_windowDuration: Duration,
_slideDuration: Duration,
partitioner: Partitioner
@@ -87,21 +88,24 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
//
// Get the RDDs of the reduced values in "old time steps"
- val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
+ val oldRDDs =
+ reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
logDebug("# old RDDs = " + oldRDDs.size)
// Get the RDDs of the reduced values in "new time steps"
- val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime)
+ val newRDDs =
+ reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime)
logDebug("# new RDDs = " + newRDDs.size)
// Get the RDD of the reduced value of the previous window
- val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]()))
+ val previousWindowRDD =
+ getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]()))
// Make the list of RDDs that needs to cogrouped together for reducing their reduced values
val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs
// Cogroup the reduced RDDs and merge the reduced values
- val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner)
+ val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]], partitioner)
//val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
val numOldValues = oldRDDs.size
@@ -114,7 +118,9 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
// Getting reduced values "old time steps" that will be removed from current window
val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
// Getting reduced values "new time steps"
- val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+ val newValues =
+ (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+
if (seqOfValues(0).isEmpty) {
// If previous window's reduce value does not exist, then at least new values should exist
if (newValues.isEmpty) {
@@ -140,10 +146,12 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues)
- Some(mergedValuesRDD)
+ if (filterFunc.isDefined) {
+ Some(mergedValuesRDD.filter(filterFunc.get))
+ } else {
+ Some(mergedValuesRDD)
+ }
}
-
-
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
index d42027092b..4af839ad7f 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
@@ -15,7 +15,7 @@ class SocketInputDStream[T: ClassManifest](
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_) {
- def createReceiver(): NetworkReceiver[T] = {
+ def getReceiver(): NetworkReceiver[T] = {
new SocketReceiver(host, port, bytesToObjects, storageLevel)
}
}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
index b4506c74aa..db62955036 100644
--- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
@@ -48,8 +48,16 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest](
//logDebug("Generating state RDD for time " + validTime)
return Some(stateRDD)
}
- case None => { // If parent RDD does not exist, then return old state RDD
- return Some(prevStateRDD)
+ case None => { // If parent RDD does not exist
+
+ // Re-apply the update function to the old state RDD
+ val updateFuncLocal = updateFunc
+ val finalFunc = (iterator: Iterator[(K, S)]) => {
+ val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
+ updateFuncLocal(i)
+ }
+ val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
+ return Some(stateRDD)
}
}
}
diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
index 99ed4cdc1c..c697498862 100644
--- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala
@@ -1,12 +1,11 @@
-package spark.streaming.examples.twitter
+package spark.streaming.dstream
import spark._
import spark.streaming._
-import dstream.{NetworkReceiver, NetworkInputDStream}
import storage.StorageLevel
+
import twitter4j._
import twitter4j.auth.BasicAuthorization
-import collection.JavaConversions._
/* A stream of Twitter statuses, potentially filtered by one or more keywords.
*
@@ -14,19 +13,21 @@ import collection.JavaConversions._
* An optional set of string filters can be used to restrict the set of tweets. The Twitter API is
* such that this may return a sampled subset of all tweets during each interval.
*/
+private[streaming]
class TwitterInputDStream(
@transient ssc_ : StreamingContext,
username: String,
password: String,
filters: Seq[String],
storageLevel: StorageLevel
- ) extends NetworkInputDStream[Status](ssc_) {
+ ) extends NetworkInputDStream[Status](ssc_) {
- override def createReceiver(): NetworkReceiver[Status] = {
+ override def getReceiver(): NetworkReceiver[Status] = {
new TwitterReceiver(username, password, filters, storageLevel)
}
}
+private[streaming]
class TwitterReceiver(
username: String,
password: String,
@@ -50,7 +51,7 @@ class TwitterReceiver(
def onTrackLimitationNotice(i: Int) {}
def onScrubGeo(l: Long, l1: Long) {}
def onStallWarning(stallWarning: StallWarning) {}
- def onException(e: Exception) {}
+ def onException(e: Exception) { stopOnError(e) }
})
val query: FilterQuery = new FilterQuery
diff --git a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala
new file mode 100644
index 0000000000..b3201d0b28
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala
@@ -0,0 +1,153 @@
+package spark.streaming.receivers
+
+import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy }
+import akka.actor.{ actorRef2Scala, ActorRef }
+import akka.actor.{ PossiblyHarmful, OneForOneStrategy }
+
+import spark.storage.StorageLevel
+import spark.streaming.dstream.NetworkReceiver
+
+import java.util.concurrent.atomic.AtomicInteger
+
+/** A helper with set of defaults for supervisor strategy **/
+object ReceiverSupervisorStrategy {
+
+ import akka.util.duration._
+ import akka.actor.SupervisorStrategy._
+
+ val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange =
+ 15 millis) {
+ case _: RuntimeException ⇒ Restart
+ case _: Exception ⇒ Escalate
+ }
+}
+
+/**
+ * A receiver trait to be mixed in with your Actor to gain access to
+ * pushBlock API.
+ *
+ * @example {{{
+ * class MyActor extends Actor with Receiver{
+ * def receive {
+ * case anything :String ⇒ pushBlock(anything)
+ * }
+ * }
+ * //Can be plugged in actorStream as follows
+ * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver")
+ *
+ * }}}
+ *
+ * @note An important point to note:
+ * Since Actor may exist outside the spark framework, It is thus user's responsibility
+ * to ensure the type safety, i.e parametrized type of push block and InputDStream
+ * should be same.
+ *
+ */
+trait Receiver { self: Actor ⇒
+ def pushBlock[T: ClassManifest](iter: Iterator[T]) {
+ context.parent ! Data(iter)
+ }
+
+ def pushBlock[T: ClassManifest](data: T) {
+ context.parent ! Data(data)
+ }
+
+}
+
+/**
+ * Statistics for querying the supervisor about state of workers
+ */
+case class Statistics(numberOfMsgs: Int,
+ numberOfWorkers: Int,
+ numberOfHiccups: Int,
+ otherInfo: String)
+
+/** Case class to receive data sent by child actors **/
+private[streaming] case class Data[T: ClassManifest](data: T)
+
+/**
+ * Provides Actors as receivers for receiving stream.
+ *
+ * As Actors can also be used to receive data from almost any stream source.
+ * A nice set of abstraction(s) for actors as receivers is already provided for
+ * a few general cases. It is thus exposed as an API where user may come with
+ * his own Actor to run as receiver for Spark Streaming input source.
+ *
+ * This starts a supervisor actor which starts workers and also provides
+ * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance].
+ *
+ * Here's a way to start more supervisor/workers as its children.
+ *
+ * @example {{{
+ * context.parent ! Props(new Supervisor)
+ * }}} OR {{{
+ * context.parent ! Props(new Worker,"Worker")
+ * }}}
+ *
+ *
+ */
+private[streaming] class ActorReceiver[T: ClassManifest](
+ props: Props,
+ name: String,
+ storageLevel: StorageLevel,
+ receiverSupervisorStrategy: SupervisorStrategy)
+ extends NetworkReceiver[T] {
+
+ protected lazy val blocksGenerator: BlockGenerator =
+ new BlockGenerator(storageLevel)
+
+ protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor),
+ "Supervisor" + streamId)
+
+ private class Supervisor extends Actor {
+
+ override val supervisorStrategy = receiverSupervisorStrategy
+ val worker = context.actorOf(props, name)
+ logInfo("Started receiver worker at:" + worker.path)
+
+ val n: AtomicInteger = new AtomicInteger(0)
+ val hiccups: AtomicInteger = new AtomicInteger(0)
+
+ def receive = {
+
+ case Data(iter: Iterator[_]) ⇒ pushBlock(iter.asInstanceOf[Iterator[T]])
+
+ case Data(msg) ⇒
+ blocksGenerator += msg.asInstanceOf[T]
+ n.incrementAndGet
+
+ case props: Props ⇒
+ val worker = context.actorOf(props)
+ logInfo("Started receiver worker at:" + worker.path)
+ sender ! worker
+
+ case (props: Props, name: String) ⇒
+ val worker = context.actorOf(props, name)
+ logInfo("Started receiver worker at:" + worker.path)
+ sender ! worker
+
+ case _: PossiblyHarmful => hiccups.incrementAndGet()
+
+ case _: Statistics ⇒
+ val workers = context.children
+ sender ! Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n"))
+
+ }
+ }
+
+ protected def pushBlock(iter: Iterator[T]) {
+ pushBlock("block-" + streamId + "-" + System.nanoTime(),
+ iter, null, storageLevel)
+ }
+
+ protected def onStart() = {
+ blocksGenerator.start()
+ supervisor
+ logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+ }
+
+ protected def onStop() = {
+ supervisor ! PoisonPill
+ }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala
new file mode 100644
index 0000000000..5533c3cf1e
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala
@@ -0,0 +1,33 @@
+package spark.streaming.receivers
+
+import akka.actor.Actor
+import akka.zeromq._
+
+import spark.Logging
+
+/**
+ * A receiver to subscribe to ZeroMQ stream.
+ */
+private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String,
+ subscribe: Subscribe,
+ bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T])
+ extends Actor with Receiver with Logging {
+
+ override def preStart() = context.system.newSocket(SocketType.Sub, Listener(self),
+ Connect(publisherUrl), subscribe)
+
+ def receive: Receive = {
+
+ case Connecting ⇒ logInfo("connecting ...")
+
+ case m: ZMQMessage ⇒
+ logDebug("Received message for:" + m.firstFrameAsString)
+
+ //We ignore first frame for processing as it is the topic
+ val bytes = m.frames.tail.map(_.payload)
+ pushBlock(bytesToObjects(bytes))
+
+ case Closed ⇒ logInfo("received closed ")
+
+ }
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
new file mode 100644
index 0000000000..bdd9f4d753
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala
@@ -0,0 +1,392 @@
+package spark.streaming.util
+
+import spark.{Logging, RDD}
+import spark.streaming._
+import spark.streaming.dstream.ForEachDStream
+import StreamingContext._
+
+import scala.util.Random
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+
+import java.io.{File, ObjectInputStream, IOException}
+import java.util.UUID
+
+import com.google.common.io.Files
+
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.fs.{FileUtil, FileSystem, Path}
+import org.apache.hadoop.conf.Configuration
+
+
+private[streaming]
+object MasterFailureTest extends Logging {
+ initLogging()
+
+ @volatile var killed = false
+ @volatile var killCount = 0
+
+ def main(args: Array[String]) {
+ if (args.size < 2) {
+ println(
+ "Usage: MasterFailureTest <local/HDFS directory> <# batches> [<batch size in milliseconds>]")
+ System.exit(1)
+ }
+ val directory = args(0)
+ val numBatches = args(1).toInt
+ val batchDuration = if (args.size > 2) Milliseconds(args(2).toInt) else Seconds(1)
+
+ println("\n\n========================= MAP TEST =========================\n\n")
+ testMap(directory, numBatches, batchDuration)
+
+ println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n")
+ testUpdateStateByKey(directory, numBatches, batchDuration)
+
+ println("\n\nSUCCESS\n\n")
+ }
+
+ def testMap(directory: String, numBatches: Int, batchDuration: Duration) {
+ // Input: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ...
+ val input = (1 to numBatches).map(_.toString).toSeq
+ // Expected output: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ...
+ val expectedOutput = (1 to numBatches)
+
+ val operation = (st: DStream[String]) => st.map(_.toInt)
+
+ // Run streaming operation with multiple master failures
+ val output = testOperation(directory, batchDuration, input, operation, expectedOutput)
+
+ logInfo("Expected output, size = " + expectedOutput.size)
+ logInfo(expectedOutput.mkString("[", ",", "]"))
+ logInfo("Output, size = " + output.size)
+ logInfo(output.mkString("[", ",", "]"))
+
+ // Verify whether all the values of the expected output is present
+ // in the output
+ assert(output.distinct.toSet == expectedOutput.toSet)
+ }
+
+
+ def testUpdateStateByKey(directory: String, numBatches: Int, batchDuration: Duration) {
+ // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ...
+ val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq
+ // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ...
+ val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j))
+
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Long], state: Option[Long]) => {
+ Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L))
+ }
+ st.flatMap(_.split(" "))
+ .map(x => (x, 1L))
+ .updateStateByKey[Long](updateFunc)
+ .checkpoint(batchDuration * 5)
+ }
+
+ // Run streaming operation with multiple master failures
+ val output = testOperation(directory, batchDuration, input, operation, expectedOutput)
+
+ logInfo("Expected output, size = " + expectedOutput.size + "\n" + expectedOutput)
+ logInfo("Output, size = " + output.size + "\n" + output)
+
+ // Verify whether all the values in the output are among the expected output values
+ output.foreach(o =>
+ assert(expectedOutput.contains(o), "Expected value " + o + " not found")
+ )
+
+ // Verify whether the last expected output value has been generated, there by
+ // confirming that none of the inputs have been missed
+ assert(output.last == expectedOutput.last)
+ }
+
+ /**
+ * Tests stream operation with multiple master failures, and verifies whether the
+ * final set of output values is as expected or not.
+ */
+ def testOperation[T: ClassManifest](
+ directory: String,
+ batchDuration: Duration,
+ input: Seq[String],
+ operation: DStream[String] => DStream[T],
+ expectedOutput: Seq[T]
+ ): Seq[T] = {
+
+ // Just making sure that the expected output does not have duplicates
+ assert(expectedOutput.distinct.toSet == expectedOutput.toSet)
+
+ // Setup the stream computation with the given operation
+ val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation)
+
+ // Start generating files in the a different thread
+ val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds)
+ fileGeneratingThread.start()
+
+ // Run the streams and repeatedly kill it until the last expected output
+ // has been generated, or until it has run for twice the expected time
+ val lastExpectedOutput = expectedOutput.last
+ val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2
+ val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun)
+
+ // Delete directories
+ fileGeneratingThread.join()
+ val fs = checkpointDir.getFileSystem(new Configuration())
+ fs.delete(checkpointDir, true)
+ fs.delete(testDir, true)
+ logInfo("Finished test after " + killCount + " failures")
+ mergedOutput
+ }
+
+ /**
+ * Sets up the stream computation with the given operation, directory (local or HDFS),
+ * and batch duration. Returns the streaming context and the directory to which
+ * files should be written for testing.
+ */
+ private def setupStreams[T: ClassManifest](
+ directory: String,
+ batchDuration: Duration,
+ operation: DStream[String] => DStream[T]
+ ): (StreamingContext, Path, Path) = {
+ // Reset all state
+ reset()
+
+ // Create the directories for this test
+ val uuid = UUID.randomUUID().toString
+ val rootDir = new Path(directory, uuid)
+ val fs = rootDir.getFileSystem(new Configuration())
+ val checkpointDir = new Path(rootDir, "checkpoint")
+ val testDir = new Path(rootDir, "test")
+ fs.mkdirs(checkpointDir)
+ fs.mkdirs(testDir)
+
+ // Setup the streaming computation with the given operation
+ System.clearProperty("spark.driver.port")
+ var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration)
+ ssc.checkpoint(checkpointDir.toString)
+ val inputStream = ssc.textFileStream(testDir.toString)
+ val operatedStream = operation(inputStream)
+ val outputStream = new TestOutputStream(operatedStream)
+ ssc.registerOutputStream(outputStream)
+ (ssc, checkpointDir, testDir)
+ }
+
+
+ /**
+ * Repeatedly starts and kills the streaming context until timed out or
+ * the last expected output is generated. Finally, return
+ */
+ private def runStreams[T: ClassManifest](
+ ssc_ : StreamingContext,
+ lastExpectedOutput: T,
+ maxTimeToRun: Long
+ ): Seq[T] = {
+
+ var ssc = ssc_
+ var totalTimeRan = 0L
+ var isLastOutputGenerated = false
+ var isTimedOut = false
+ val mergedOutput = new ArrayBuffer[T]()
+ val checkpointDir = ssc.checkpointDir
+ var batchDuration = ssc.graph.batchDuration
+
+ while(!isLastOutputGenerated && !isTimedOut) {
+ // Get the output buffer
+ val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output
+ def output = outputBuffer.flatMap(x => x)
+
+ // Start the thread to kill the streaming after some time
+ killed = false
+ val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10)
+ killingThread.start()
+
+ var timeRan = 0L
+ try {
+ // Start the streaming computation and let it run while ...
+ // (i) StreamingContext has not been shut down yet
+ // (ii) The last expected output has not been generated yet
+ // (iii) Its not timed out yet
+ System.clearProperty("spark.streaming.clock")
+ System.clearProperty("spark.driver.port")
+ ssc.start()
+ val startTime = System.currentTimeMillis()
+ while (!killed && !isLastOutputGenerated && !isTimedOut) {
+ Thread.sleep(100)
+ timeRan = System.currentTimeMillis() - startTime
+ isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput)
+ isTimedOut = (timeRan + totalTimeRan > maxTimeToRun)
+ }
+ } catch {
+ case e: Exception => logError("Error running streaming context", e)
+ }
+ if (killingThread.isAlive) killingThread.interrupt()
+ ssc.stop()
+
+ logInfo("Has been killed = " + killed)
+ logInfo("Is last output generated = " + isLastOutputGenerated)
+ logInfo("Is timed out = " + isTimedOut)
+
+ // Verify whether the output of each batch has only one element or no element
+ // and then merge the new output with all the earlier output
+ mergedOutput ++= output
+ totalTimeRan += timeRan
+ logInfo("New output = " + output)
+ logInfo("Merged output = " + mergedOutput)
+ logInfo("Time ran = " + timeRan)
+ logInfo("Total time ran = " + totalTimeRan)
+
+ if (!isLastOutputGenerated && !isTimedOut) {
+ val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 10)
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation in " + sleepTime + " ms " +
+ "\n-------------------------------------------\n"
+ )
+ Thread.sleep(sleepTime)
+ // Recreate the streaming context from checkpoint
+ ssc = new StreamingContext(checkpointDir)
+ }
+ }
+ mergedOutput
+ }
+
+ /**
+ * Verifies the output value are the same as expected. Since failures can lead to
+ * a batch being processed twice, a batches output may appear more than once
+ * consecutively. To avoid getting confused with those, we eliminate consecutive
+ * duplicate batch outputs of values from the `output`. As a result, the
+ * expected output should not have consecutive batches with the same values as output.
+ */
+ private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) {
+ // Verify whether expected outputs do not consecutive batches with same output
+ for (i <- 0 until expectedOutput.size - 1) {
+ assert(expectedOutput(i) != expectedOutput(i+1),
+ "Expected output has consecutive duplicate sequence of values")
+ }
+
+ // Log the output
+ println("Expected output, size = " + expectedOutput.size)
+ println(expectedOutput.mkString("[", ",", "]"))
+ println("Output, size = " + output.size)
+ println(output.mkString("[", ",", "]"))
+
+ // Match the output with the expected output
+ output.foreach(o =>
+ assert(expectedOutput.contains(o), "Expected value " + o + " not found")
+ )
+ }
+
+ /** Resets counter to prepare for the test */
+ private def reset() {
+ killed = false
+ killCount = 0
+ }
+}
+
+/**
+ * This is a output stream just for testing. All the output is collected into a
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ */
+private[streaming]
+class TestOutputStream[T: ClassManifest](
+ parent: DStream[T],
+ val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+ ) extends ForEachDStream[T](
+ parent,
+ (rdd: RDD[T], t: Time) => {
+ val collected = rdd.collect()
+ output += collected
+ }
+ ) {
+
+ // This is to clear the output buffer every it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
+
+
+/**
+ * Thread to kill streaming context after a random period of time.
+ */
+private[streaming]
+class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging {
+ initLogging()
+
+ override def run() {
+ try {
+ // If it is the first killing, then allow the first checkpoint to be created
+ var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 2000
+ val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime)
+ logInfo("Kill wait time = " + killWaitTime)
+ Thread.sleep(killWaitTime)
+ logInfo(
+ "\n---------------------------------------\n" +
+ "Killing streaming context after " + killWaitTime + " ms" +
+ "\n---------------------------------------\n"
+ )
+ if (ssc != null) {
+ ssc.stop()
+ MasterFailureTest.killed = true
+ MasterFailureTest.killCount += 1
+ }
+ logInfo("Killing thread finished normally")
+ } catch {
+ case ie: InterruptedException => logInfo("Killing thread interrupted")
+ case e: Exception => logWarning("Exception in killing thread", e)
+ }
+
+ }
+}
+
+
+/**
+ * Thread to generate input files periodically with the desired text.
+ */
+private[streaming]
+class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
+ extends Thread with Logging {
+ initLogging()
+
+ override def run() {
+ val localTestDir = Files.createTempDir()
+ var fs = testDir.getFileSystem(new Configuration())
+ val maxTries = 3
+ try {
+ Thread.sleep(5000) // To make sure that all the streaming context has been set up
+ for (i <- 0 until input.size) {
+ // Write the data to a local file and then move it to the target test directory
+ val localFile = new File(localTestDir, (i+1).toString)
+ val hadoopFile = new Path(testDir, (i+1).toString)
+ FileUtils.writeStringToFile(localFile, input(i).toString + "\n")
+ var tries = 0
+ var done = false
+ while (!done && tries < maxTries) {
+ tries += 1
+ try {
+ fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ done = true
+ } catch {
+ case ioe: IOException => {
+ fs = testDir.getFileSystem(new Configuration())
+ logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe)
+ }
+ }
+ }
+ if (!done)
+ logError("Could not generate file " + hadoopFile)
+ else
+ logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis)
+ Thread.sleep(interval)
+ localFile.delete()
+ }
+ logInfo("File generating thread finished normally")
+ } catch {
+ case ie: InterruptedException => logInfo("File generating thread interrupted")
+ case e: Exception => logWarning("File generating in killing thread", e)
+ } finally {
+ fs.close()
+ }
+ }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
index db715cc295..8e10276deb 100644
--- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -3,9 +3,9 @@ package spark.streaming.util
private[streaming]
class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
- val minPollTime = 25L
+ private val minPollTime = 25L
- val pollTime = {
+ private val pollTime = {
if (period / 10.0 > minPollTime) {
(period / 10.0).toLong
} else {
@@ -13,11 +13,20 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
}
}
- val thread = new Thread() {
+ private val thread = new Thread() {
override def run() { loop }
}
- var nextTime = 0L
+ private var nextTime = 0L
+
+ def getStartTime(): Long = {
+ (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
+ }
+
+ def getRestartTime(originalStartTime: Long): Long = {
+ val gap = clock.currentTime - originalStartTime
+ (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
+ }
def start(startTime: Long): Long = {
nextTime = startTime
@@ -26,21 +35,14 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
}
def start(): Long = {
- val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
- start(startTime)
+ start(getStartTime())
}
- def restart(originalStartTime: Long): Long = {
- val gap = clock.currentTime - originalStartTime
- val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
- start(newStartTime)
- }
-
- def stop() {
+ def stop() {
thread.interrupt()
}
- def loop() {
+ private def loop() {
try {
while (true) {
clock.waitTillTime(nextTime)
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
index 79d6093429..3bed500f73 100644
--- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -11,7 +11,10 @@ import org.junit.Before;
import org.junit.Test;
import scala.Tuple2;
import spark.HashPartitioner;
+import spark.api.java.JavaPairRDD;
import spark.api.java.JavaRDD;
+import spark.api.java.JavaRDDLike;
+import spark.api.java.JavaPairRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.*;
import spark.storage.StorageLevel;
@@ -21,10 +24,16 @@ import spark.streaming.api.java.JavaStreamingContext;
import spark.streaming.JavaTestUtils;
import spark.streaming.JavaCheckpointTestUtils;
import spark.streaming.dstream.KafkaPartitionKey;
+import spark.streaming.InputStreamsSuite;
import java.io.*;
import java.util.*;
+import akka.actor.Props;
+import akka.zeromq.Subscribe;
+
+
+
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
@@ -33,8 +42,9 @@ public class JavaAPISuite implements Serializable {
@Before
public void setUp() {
- ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
- ssc.checkpoint("checkpoint", new Duration(1000));
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock");
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint");
}
@After
@@ -134,29 +144,6 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void testTumble() {
- List<List<Integer>> inputData = Arrays.asList(
- Arrays.asList(1,2,3),
- Arrays.asList(4,5,6),
- Arrays.asList(7,8,9),
- Arrays.asList(10,11,12),
- Arrays.asList(13,14,15),
- Arrays.asList(16,17,18));
-
- List<List<Integer>> expected = Arrays.asList(
- Arrays.asList(1,2,3,4,5,6),
- Arrays.asList(7,8,9,10,11,12),
- Arrays.asList(13,14,15,16,17,18));
-
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream windowed = stream.tumble(new Duration(2000));
- JavaTestUtils.attachTestOutputStream(windowed);
- List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 6, 3);
-
- assertOrderInvariantEquals(expected, result);
- }
-
- @Test
public void testFilter() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("giants", "dodgers"),
@@ -315,8 +302,9 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(6,7,8),
Arrays.asList(9,10,11));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream transformed = stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> transformed =
+ stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
@Override
public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
return in.map(new Function<Integer, Integer>() {
@@ -507,6 +495,141 @@ public class JavaAPISuite implements Serializable {
new Tuple2<String, Integer>("new york", 1)));
@Test
+ public void testPairMap() { // Maps pair -> pair of different type
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, String>(1, "california"),
+ new Tuple2<Integer, String>(3, "california"),
+ new Tuple2<Integer, String>(4, "new york"),
+ new Tuple2<Integer, String>(1, "new york")),
+ Arrays.asList(
+ new Tuple2<Integer, String>(5, "california"),
+ new Tuple2<Integer, String>(5, "california"),
+ new Tuple2<Integer, String>(3, "new york"),
+ new Tuple2<Integer, String>(1, "new york")));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ JavaPairDStream<Integer, String> reversed = pairStream.map(
+ new PairFunction<Tuple2<String, Integer>, Integer, String>() {
+ @Override
+ public Tuple2<Integer, String> call(Tuple2<String, Integer> in) throws Exception {
+ return in.swap();
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(reversed);
+ List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPairMapPartitions() { // Maps pair -> pair of different type
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, String>(1, "california"),
+ new Tuple2<Integer, String>(3, "california"),
+ new Tuple2<Integer, String>(4, "new york"),
+ new Tuple2<Integer, String>(1, "new york")),
+ Arrays.asList(
+ new Tuple2<Integer, String>(5, "california"),
+ new Tuple2<Integer, String>(5, "california"),
+ new Tuple2<Integer, String>(3, "new york"),
+ new Tuple2<Integer, String>(1, "new york")));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ JavaPairDStream<Integer, String> reversed = pairStream.mapPartitions(
+ new PairFlatMapFunction<Iterator<Tuple2<String, Integer>>, Integer, String>() {
+ @Override
+ public Iterable<Tuple2<Integer, String>> call(Iterator<Tuple2<String, Integer>> in) throws Exception {
+ LinkedList<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>();
+ while (in.hasNext()) {
+ Tuple2<String, Integer> next = in.next();
+ out.add(next.swap());
+ }
+ return out;
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(reversed);
+ List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPairMap2() { // Maps pair -> single
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1, 3, 4, 1),
+ Arrays.asList(5, 5, 3, 1));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ JavaDStream<Integer> reversed = pairStream.map(
+ new Function<Tuple2<String, Integer>, Integer>() {
+ @Override
+ public Integer call(Tuple2<String, Integer> in) throws Exception {
+ return in._2();
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(reversed);
+ List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair
+ List<List<Tuple2<String, Integer>>> inputData = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Integer>("hi", 1),
+ new Tuple2<String, Integer>("ho", 2)),
+ Arrays.asList(
+ new Tuple2<String, Integer>("hi", 1),
+ new Tuple2<String, Integer>("ho", 2)));
+
+ List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, String>(1, "h"),
+ new Tuple2<Integer, String>(1, "i"),
+ new Tuple2<Integer, String>(2, "h"),
+ new Tuple2<Integer, String>(2, "o")),
+ Arrays.asList(
+ new Tuple2<Integer, String>(1, "h"),
+ new Tuple2<Integer, String>(1, "i"),
+ new Tuple2<Integer, String>(2, "h"),
+ new Tuple2<Integer, String>(2, "o")));
+
+ JavaDStream<Tuple2<String, Integer>> stream =
+ JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ JavaPairDStream<Integer, String> flatMapped = pairStream.flatMap(
+ new PairFlatMapFunction<Tuple2<String, Integer>, Integer, String>() {
+ @Override
+ public Iterable<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) throws Exception {
+ List<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>();
+ for (Character s : in._1().toCharArray()) {
+ out.add(new Tuple2<Integer, String>(in._2(), s.toString()));
+ }
+ return out;
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(flatMapped);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
public void testPairGroupByKey() {
List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
@@ -570,7 +693,7 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(
new Function<Integer, Integer>() {
- @Override
+ @Override
public Integer call(Integer i) throws Exception {
return i;
}
@@ -583,50 +706,73 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void testCountByKey() {
- List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+ public void testCountByValue() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("hello", "world"),
+ Arrays.asList("hello", "moon"),
+ Arrays.asList("hello"));
List<List<Tuple2<String, Long>>> expected = Arrays.asList(
- Arrays.asList(
- new Tuple2<String, Long>("california", 2L),
- new Tuple2<String, Long>("new york", 2L)),
- Arrays.asList(
- new Tuple2<String, Long>("california", 2L),
- new Tuple2<String, Long>("new york", 2L)));
-
- JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
- ssc, inputData, 1);
- JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L),
+ new Tuple2<String, Long>("world", 1L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L),
+ new Tuple2<String, Long>("moon", 1L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("hello", 1L)));
- JavaPairDStream<String, Long> counted = pairStream.countByKey();
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Long> counted = stream.countByValue();
JavaTestUtils.attachTestOutputStream(counted);
- List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
Assert.assertEquals(expected, result);
}
@Test
public void testGroupByKeyAndWindow() {
- List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
- List<List<Tuple2<String, List<String>>>> expected = Arrays.asList(
- Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("dodgers", "giants")),
- new Tuple2<String, List<String>>("new york", Arrays.asList("yankees", "mets"))),
- Arrays.asList(new Tuple2<String, List<String>>("california",
- Arrays.asList("sharks", "ducks", "dodgers", "giants")),
- new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))),
- Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("sharks", "ducks")),
- new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders"))));
+ List<List<Tuple2<String, List<Integer>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, List<Integer>>("california", Arrays.asList(1, 3)),
+ new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 4))
+ ),
+ Arrays.asList(
+ new Tuple2<String, List<Integer>>("california", Arrays.asList(1, 3, 5, 5)),
+ new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 1, 3, 4))
+ ),
+ Arrays.asList(
+ new Tuple2<String, List<Integer>>("california", Arrays.asList(5, 5)),
+ new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 3))
+ )
+ );
- JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
- JavaPairDStream<String, List<String>> groupWindowed =
+ JavaPairDStream<String, List<Integer>> groupWindowed =
pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(groupWindowed);
- List<List<Tuple2<String, List<String>>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+ List<List<Tuple2<String, List<Integer>>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
- Assert.assertEquals(expected, result);
+ assert(result.size() == expected.size());
+ for (int i = 0; i < result.size(); i++) {
+ assert(convert(result.get(i)).equals(convert(expected.get(i))));
+ }
+ }
+
+ private HashSet<Tuple2<String, HashSet<Integer>>> convert(List<Tuple2<String, List<Integer>>> listOfTuples) {
+ List<Tuple2<String, HashSet<Integer>>> newListOfTuples = new ArrayList<Tuple2<String, HashSet<Integer>>>();
+ for (Tuple2<String, List<Integer>> tuple: listOfTuples) {
+ newListOfTuples.add(convert(tuple));
+ }
+ return new HashSet<Tuple2<String, HashSet<Integer>>>(newListOfTuples);
+ }
+
+ private Tuple2<String, HashSet<Integer>> convert(Tuple2<String, List<Integer>> tuple) {
+ return new Tuple2<String, HashSet<Integer>>(tuple._1(), new HashSet<Integer>(tuple._2()));
}
@Test
@@ -668,7 +814,7 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
- new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>(){
+ new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
@Override
public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
int out = 0;
@@ -680,7 +826,7 @@ public class JavaAPISuite implements Serializable {
}
return Optional.of(out);
}
- });
+ });
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -711,26 +857,28 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void testCountByKeyAndWindow() {
- List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+ public void testCountByValueAndWindow() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("hello", "world"),
+ Arrays.asList("hello", "moon"),
+ Arrays.asList("hello"));
List<List<Tuple2<String, Long>>> expected = Arrays.asList(
Arrays.asList(
- new Tuple2<String, Long>("california", 2L),
- new Tuple2<String, Long>("new york", 2L)),
+ new Tuple2<String, Long>("hello", 1L),
+ new Tuple2<String, Long>("world", 1L)),
Arrays.asList(
- new Tuple2<String, Long>("california", 4L),
- new Tuple2<String, Long>("new york", 4L)),
+ new Tuple2<String, Long>("hello", 2L),
+ new Tuple2<String, Long>("world", 1L),
+ new Tuple2<String, Long>("moon", 1L)),
Arrays.asList(
- new Tuple2<String, Long>("california", 2L),
- new Tuple2<String, Long>("new york", 2L)));
+ new Tuple2<String, Long>("hello", 2L),
+ new Tuple2<String, Long>("moon", 1L)));
- JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(
ssc, inputData, 1);
- JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
-
JavaPairDStream<String, Long> counted =
- pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000));
+ stream.countByValueAndWindow(new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(counted);
List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -738,6 +886,90 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testPairTransform() {
+ List<List<Tuple2<Integer, Integer>>> inputData = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(1, 5),
+ new Tuple2<Integer, Integer>(4, 5),
+ new Tuple2<Integer, Integer>(2, 5)),
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 5),
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(4, 5),
+ new Tuple2<Integer, Integer>(1, 5)));
+
+ List<List<Tuple2<Integer, Integer>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 5),
+ new Tuple2<Integer, Integer>(2, 5),
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(4, 5)),
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 5),
+ new Tuple2<Integer, Integer>(2, 5),
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(4, 5)));
+
+ JavaDStream<Tuple2<Integer, Integer>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<Integer, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<Integer, Integer> sorted = pairStream.transform(
+ new Function<JavaPairRDD<Integer, Integer>, JavaPairRDD<Integer, Integer>>() {
+ @Override
+ public JavaPairRDD<Integer, Integer> call(JavaPairRDD<Integer, Integer> in) throws Exception {
+ return in.sortByKey();
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(sorted);
+ List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPairToNormalRDDTransform() {
+ List<List<Tuple2<Integer, Integer>>> inputData = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(1, 5),
+ new Tuple2<Integer, Integer>(4, 5),
+ new Tuple2<Integer, Integer>(2, 5)),
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(2, 5),
+ new Tuple2<Integer, Integer>(3, 5),
+ new Tuple2<Integer, Integer>(4, 5),
+ new Tuple2<Integer, Integer>(1, 5)));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(3,1,4,2),
+ Arrays.asList(2,3,4,1));
+
+ JavaDStream<Tuple2<Integer, Integer>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<Integer, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaDStream<Integer> firstParts = pairStream.transform(
+ new Function<JavaPairRDD<Integer, Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaPairRDD<Integer, Integer> in) throws Exception {
+ return in.map(new Function<Tuple2<Integer, Integer>, Integer>() {
+ @Override
+ public Integer call(Tuple2<Integer, Integer> in) {
+ return in._1();
+ }
+ });
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(firstParts);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
public void testMapValues() {
List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
@@ -911,9 +1143,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(1,4),
Arrays.asList(8,7));
-
File tempDir = Files.createTempDir();
- ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000));
+ ssc.checkpoint(tempDir.getAbsolutePath());
JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaDStream letterCount = stream.map(new Function<String, Integer>() {
@@ -927,14 +1158,16 @@ public class JavaAPISuite implements Serializable {
assertOrderInvariantEquals(expectedInitial, initialResult);
Thread.sleep(1000);
-
ssc.stop();
+
ssc = new JavaStreamingContext(tempDir.getAbsolutePath());
- ssc.start();
- List<List<Integer>> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2);
- assertOrderInvariantEquals(expectedFinal, finalResult);
+ // Tweak to take into consideration that the last batch before failure
+ // will be re-processed after recovery
+ List<List<Integer>> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 3);
+ assertOrderInvariantEquals(expectedFinal, finalResult.subList(1, 3));
}
+
/** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD
@Test
public void testCheckpointofIndividualStream() throws InterruptedException {
@@ -971,19 +1204,19 @@ public class JavaAPISuite implements Serializable {
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
- JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics);
- JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets);
- JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets,
+ JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
+ JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
+ JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
StorageLevel.MEMORY_AND_DISK());
}
@Test
- public void testNetworkTextStream() {
- JavaDStream test = ssc.networkTextStream("localhost", 12345);
+ public void testSocketTextStream() {
+ JavaDStream test = ssc.socketTextStream("localhost", 12345);
}
@Test
- public void testNetworkString() {
+ public void testSocketString() {
class Converter extends Function<InputStream, Iterable<String>> {
public Iterable<String> call(InputStream in) {
BufferedReader reader = new BufferedReader(new InputStreamReader(in));
@@ -999,7 +1232,7 @@ public class JavaAPISuite implements Serializable {
}
}
- JavaDStream test = ssc.networkStream(
+ JavaDStream test = ssc.socketStream(
"localhost",
12345,
new Converter(),
@@ -1012,13 +1245,13 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void testRawNetworkStream() {
- JavaDStream test = ssc.rawNetworkStream("localhost", 12345);
+ public void testRawSocketStream() {
+ JavaDStream test = ssc.rawSocketStream("localhost", 12345);
}
@Test
public void testFlumeStream() {
- JavaDStream test = ssc.flumeStream("localhost", 12345);
+ JavaDStream test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY());
}
@Test
@@ -1026,4 +1259,25 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, String> foo =
ssc.<String, String, SequenceFileInputFormat>fileStream("/tmp/foo");
}
+
+ @Test
+ public void testTwitterStream() {
+ String[] filters = new String[] { "good", "bad", "ugly" };
+ JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY());
+ }
+
+ @Test
+ public void testActorStream() {
+ JavaDStream test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY());
+ }
+
+ @Test
+ public void testZeroMQStream() {
+ JavaDStream test = ssc.zeroMQStream("url", (Subscribe) null, new Function<byte[][], Iterable<String>>() {
+ @Override
+ public Iterable<String> call(byte[][] b) throws Exception {
+ return null;
+ }
+ });
+ }
}
diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
index 56349837e5..64a7e7cbf9 100644
--- a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
@@ -31,8 +31,9 @@ trait JavaTestBase extends TestSuiteBase {
* Attach a provided stream to it's associated StreamingContext as a
* [[spark.streaming.TestOutputStream]].
**/
- def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]](
- dstream: JavaDStreamLike[T, This]) = {
+ def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T, This, R],
+ R <: spark.api.java.JavaRDDLike[T, R]](
+ dstream: JavaDStreamLike[T, This, R]) = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
val ostream = new TestOutputStream(dstream.dstream,
@@ -57,6 +58,7 @@ trait JavaTestBase extends TestSuiteBase {
}
object JavaTestUtils extends JavaTestBase {
+ override def maxWaitTimeMillis = 20000
}
diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties
index edfa1243fa..59c445e63f 100644
--- a/streaming/src/test/resources/log4j.properties
+++ b/streaming/src/test/resources/log4j.properties
@@ -1,5 +1,6 @@
# Set everything to be logged to the file streaming/target/unit-tests.log
log4j.rootCategory=INFO, file
+# log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=streaming/target/unit-tests.log
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index 4a036f0710..8fce91853c 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -6,6 +6,8 @@ import util.ManualClock
class BasicOperationsSuite extends TestSuiteBase {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
override def framework() = "BasicOperationsSuite"
after {
@@ -22,7 +24,7 @@ class BasicOperationsSuite extends TestSuiteBase {
)
}
- test("flatmap") {
+ test("flatMap") {
val input = Seq(1 to 4, 5 to 8, 9 to 12)
testOperation(
input,
@@ -86,6 +88,23 @@ class BasicOperationsSuite extends TestSuiteBase {
)
}
+ test("count") {
+ testOperation(
+ Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
+ (s: DStream[Int]) => s.count(),
+ Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
+ )
+ }
+
+ test("countByValue") {
+ testOperation(
+ Seq(1 to 1, Seq(1, 1, 1), 1 to 2, Seq(1, 1, 2, 2)),
+ (s: DStream[Int]) => s.countByValue(),
+ Seq(Seq((1, 1L)), Seq((1, 3L)), Seq((1, 1L), (2, 1L)), Seq((2, 2L), (1, 2L))),
+ true
+ )
+ }
+
test("mapValues") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
@@ -165,6 +184,71 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}
+ test("updateStateByKey - object lifecycle") {
+ val inputData =
+ Seq(
+ Seq("a","b"),
+ null,
+ Seq("a","c","a"),
+ Seq("c"),
+ null,
+ null
+ )
+
+ val outputData =
+ Seq(
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 3), ("c", 1)),
+ Seq(("a", 3), ("c", 2)),
+ Seq(("c", 2)),
+ Seq()
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
+ class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable
+
+ // updateFunc clears a state when a StateObject is seen without new values twice in a row
+ val updateFunc = (values: Seq[Int], state: Option[StateObject]) => {
+ val stateObj = state.getOrElse(new StateObject)
+ values.foldLeft(0)(_ + _) match {
+ case 0 => stateObj.expireCounter += 1 // no new values
+ case n => { // has new values, increment and reset expireCounter
+ stateObj.counter += n
+ stateObj.expireCounter = 0
+ }
+ }
+ stateObj.expireCounter match {
+ case 2 => None // seen twice with no new values, give it the boot
+ case _ => Option(stateObj)
+ }
+ }
+ s.map(x => (x, 1)).updateStateByKey[StateObject](updateFunc).mapValues(_.counter)
+ }
+
+ testOperation(inputData, updateStateOperation, outputData, true)
+ }
+
+ test("slice") {
+ val ssc = new StreamingContext("local[2]", "BasicOperationSuite", Seconds(1))
+ val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4))
+ val stream = new TestInputStream[Int](ssc, input, 2)
+ ssc.registerInputStream(stream)
+ stream.foreach(_ => {}) // Dummy output stream
+ ssc.start()
+ Thread.sleep(2000)
+ def getInputFromSlice(fromMillis: Long, toMillis: Long) = {
+ stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet
+ }
+
+ assert(getInputFromSlice(0, 1000) == Set(1))
+ assert(getInputFromSlice(0, 2000) == Set(1, 2))
+ assert(getInputFromSlice(1000, 2000) == Set(1, 2))
+ assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4))
+ ssc.stop()
+ Thread.sleep(1000)
+ }
+
test("forgetting of RDDs - map and window operations") {
assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second")
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 563a7d1458..cac86deeaf 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -1,5 +1,6 @@
package spark.streaming
+import dstream.FileInputDStream
import spark.streaming.StreamingContext._
import java.io.File
import runtime.RichInt
@@ -7,9 +8,19 @@ import org.scalatest.BeforeAndAfter
import org.apache.commons.io.FileUtils
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.{Clock, ManualClock}
+import scala.util.Random
+import com.google.common.io.Files
+
+/**
+ * This test suites tests the checkpointing functionality of DStreams -
+ * the checkpointing of a DStream's RDDs as well as the checkpointing of
+ * the whole DStream graph.
+ */
class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
before {
FileUtils.deleteDirectory(new File(checkpointDir))
}
@@ -28,21 +39,18 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
override def batchDuration = Milliseconds(500)
- override def checkpointInterval = batchDuration
-
override def actuallyWait = true
- test("basic stream+rdd recovery") {
+ test("basic rdd checkpoints + dstream graph checkpoint recovery") {
assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
- assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration")
System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
val stateStreamCheckpointInterval = Seconds(1)
// this ensure checkpointing occurs at least once
- val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2
+ val firstNumBatches = (stateStreamCheckpointInterval / batchDuration).toLong * 2
val secondNumBatches = firstNumBatches
// Setup the streams
@@ -62,10 +70,10 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Run till a time such that at least one RDD in the stream should have been checkpointed,
// then check whether some RDD has been checkpointed or not
ssc.start()
- runStreamsWithRealDelay(ssc, firstNumBatches)
- logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]")
- assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure")
- stateStream.checkpointData.rdds.foreach {
+ advanceTimeWithRealDelay(ssc, firstNumBatches)
+ logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData)
+ assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure")
+ stateStream.checkpointData.checkpointFiles.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist")
@@ -74,8 +82,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Run till a further time such that previous checkpoint files in the stream would be deleted
// and check whether the earlier checkpoint files are deleted
- val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString))
- runStreamsWithRealDelay(ssc, secondNumBatches)
+ val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2))
+ advanceTimeWithRealDelay(ssc, secondNumBatches)
checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
ssc.stop()
@@ -90,9 +98,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Run one batch to generate a new checkpoint file and check whether some RDD
// is present in the checkpoint data or not
ssc.start()
- runStreamsWithRealDelay(ssc, 1)
- assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure")
- stateStream.checkpointData.rdds.foreach {
+ advanceTimeWithRealDelay(ssc, 1)
+ assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure")
+ stateStream.checkpointData.checkpointFiles.foreach {
case (time, data) => {
val file = new File(data.toString)
assert(file.exists(),
@@ -111,13 +119,16 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Adjust manual clock time as if it is being restarted after a delay
System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString)
ssc.start()
- runStreamsWithRealDelay(ssc, 4)
+ advanceTimeWithRealDelay(ssc, 4)
ssc.stop()
System.clearProperty("spark.streaming.manualClock.jump")
ssc = null
}
- test("map and reduceByKey") {
+ // This tests whether the systm can recover from a master failure with simple
+ // non-stateful operations. This assumes as reliable, replayable input
+ // source - TestInputDStream.
+ test("recovery with map and reduceByKey operations") {
testCheckpointedOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
(s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
@@ -126,7 +137,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
)
}
- test("reduceByKeyAndWindowInv") {
+
+ // This tests whether the ReduceWindowedDStream's RDD checkpoints works correctly such
+ // that the system can recover from a master failure. This assumes as reliable,
+ // replayable input source - TestInputDStream.
+ test("recovery with invertible reduceByKeyAndWindow operation") {
val n = 10
val w = 4
val input = (1 to n).map(_ => Seq("a")).toSeq
@@ -139,7 +154,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
testCheckpointedOperation(input, operation, output, 7)
}
- test("updateStateByKey") {
+
+ // This tests whether the StateDStream's RDD checkpoints works correctly such
+ // that the system can recover from a master failure. This assumes as reliable,
+ // replayable input source - TestInputDStream.
+ test("recovery with updateStateByKey operation") {
val input = (1 to 10).map(_ => Seq("a")).toSeq
val output = (1 to 10).map(x => Seq(("a", x))).toSeq
val operation = (st: DStream[String]) => {
@@ -154,11 +173,126 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
testCheckpointedOperation(input, operation, output, 7)
}
+ // This tests whether file input stream remembers what files were seen before
+ // the master failure and uses them again to process a large window operation.
+ // It also tests whether batches, whose processing was incomplete due to the
+ // failure, are re-processed or not.
+ test("recovery with file input stream") {
+ // Disable manual clock as FileInputDStream does not work with manual clock
+ val clockProperty = System.getProperty("spark.streaming.clock")
+ System.clearProperty("spark.streaming.clock")
+
+ // Set up the streaming context and input streams
+ val testDir = Files.createTempDir()
+ var ssc = new StreamingContext(master, framework, Seconds(1))
+ ssc.checkpoint(checkpointDir)
+ val fileStream = ssc.textFileStream(testDir.toString)
+ // Making value 3 take large time to process, to ensure that the master
+ // shuts down in the middle of processing the 3rd batch
+ val mappedStream = fileStream.map(s => {
+ val i = s.toInt
+ if (i == 3) Thread.sleep(2000)
+ i
+ })
+
+ // Reducing over a large window to ensure that recovery from master failure
+ // requires reprocessing of all the files seen before the failure
+ val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1))
+ val outputBuffer = new ArrayBuffer[Seq[Int]]
+ var outputStream = new TestOutputStream(reducedStream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Create files and advance manual clock to process them
+ //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ Thread.sleep(1000)
+ for (i <- Seq(1, 2, 3)) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
+ // wait to make sure that the file is written such that it gets shown in the file listings
+ Thread.sleep(1000)
+ }
+ logInfo("Output = " + outputStream.output.mkString(","))
+ assert(outputStream.output.size > 0, "No files processed before restart")
+ ssc.stop()
+
+ // Verify whether files created have been recorded correctly or not
+ var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]]
+ def recordedFiles = fileInputDStream.files.values.flatMap(x => x)
+ assert(!recordedFiles.filter(_.endsWith("1")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("2")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("3")).isEmpty)
+
+ // Create files while the master is down
+ for (i <- Seq(4, 5, 6)) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
+ Thread.sleep(1000)
+ }
+
+ // Recover context from checkpoint file and verify whether the files that were
+ // recorded before failure were saved and successfully recovered
+ logInfo("*********** RESTARTING ************")
+ ssc = new StreamingContext(checkpointDir)
+ fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]]
+ assert(!recordedFiles.filter(_.endsWith("1")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("2")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("3")).isEmpty)
+
+ // Restart stream computation
+ ssc.start()
+ for (i <- Seq(7, 8, 9)) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
+ Thread.sleep(1000)
+ }
+ Thread.sleep(1000)
+ logInfo("Output = " + outputStream.output.mkString("[", ", ", "]"))
+ assert(outputStream.output.size > 0, "No files processed after restart")
+ ssc.stop()
+
+ // Verify whether files created while the driver was down have been recorded or not
+ assert(!recordedFiles.filter(_.endsWith("4")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("5")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("6")).isEmpty)
+
+ // Verify whether new files created after recover have been recorded or not
+ assert(!recordedFiles.filter(_.endsWith("7")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("8")).isEmpty)
+ assert(!recordedFiles.filter(_.endsWith("9")).isEmpty)
+
+ // Append the new output to the old buffer
+ outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]]
+ outputBuffer ++= outputStream.output
+
+ val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45)
+ logInfo("--------------------------------")
+ logInfo("output, size = " + outputBuffer.size)
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output, size = " + expectedOutput.size)
+ expectedOutput.foreach(x => logInfo("[" + x + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ val output = outputBuffer.flatMap(x => x)
+ assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed
+ output.foreach(o => // To ensure all the inputs are correctly added cumulatively
+ assert(expectedOutput.contains(o), "Expected value " + o + " not found")
+ )
+ // To ensure that all the inputs were received correctly
+ assert(expectedOutput.last === output.last)
+
+ // Enable manual clock back again for other tests
+ if (clockProperty != null)
+ System.setProperty("spark.streaming.clock", clockProperty)
+ }
+
+
/**
- * Tests a streaming operation under checkpointing, by restart the operation
+ * Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which an replay
* data as required.
+ *
+ * NOTE: This takes into consideration that the last batch processed before
+ * master failure will be re-processed after restart/recovery.
*/
def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
@@ -172,11 +306,14 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
- val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
+ // because the last batch will be processed again
// Do the computation for initial number of batches, create checkpoint file and quit
ssc = setupStreams[U, V](input, operation)
- val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+ ssc.start()
+ val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
+ ssc.stop()
verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)
@@ -187,16 +324,20 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
"\n-------------------------------------------\n"
)
ssc = new StreamingContext(checkpointDir)
- val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs)
+ System.clearProperty("spark.driver.port")
+ ssc.start()
+ val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
+ // the first element will be re-processed data of the last batch before restart
verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
+ ssc.stop()
ssc = null
}
/**
* Advances the manual clock on the streaming scheduler by given number of batches.
- * It also wait for the expected amount of time for each batch.
+ * It also waits for the expected amount of time for each batch.
*/
- def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) {
+ def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logInfo("Manual clock before advancing = " + clock.time)
for (i <- 1 to numBatches.toInt) {
@@ -205,6 +346,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
}
logInfo("Manual clock after advancing = " + clock.time)
Thread.sleep(batchDuration.milliseconds)
- }
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ outputStream.output
+ }
} \ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
index c4cfffbfc1..a5fa7ab92d 100644
--- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
@@ -1,191 +1,40 @@
package spark.streaming
-import org.scalatest.BeforeAndAfter
-import org.apache.commons.io.FileUtils
+import spark.Logging
+import spark.streaming.util.MasterFailureTest
+import StreamingContext._
+
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import com.google.common.io.Files
import java.io.File
-import scala.runtime.RichInt
-import scala.util.Random
-import spark.streaming.StreamingContext._
+import org.apache.commons.io.FileUtils
import collection.mutable.ArrayBuffer
-import spark.Logging
+
/**
* This testsuite tests master failures at random times while the stream is running using
* the real clock.
*/
-class FailureSuite extends TestSuiteBase with BeforeAndAfter {
+class FailureSuite extends FunSuite with BeforeAndAfter with Logging {
+
+ var directory = "FailureSuite"
+ val numBatches = 30
+ val batchDuration = Milliseconds(1000)
before {
- FileUtils.deleteDirectory(new File(checkpointDir))
+ FileUtils.deleteDirectory(new File(directory))
}
after {
- FailureSuite.reset()
- FileUtils.deleteDirectory(new File(checkpointDir))
-
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- }
-
- override def framework = "CheckpointSuite"
-
- override def batchDuration = Milliseconds(500)
-
- override def checkpointDir = "checkpoint"
-
- override def checkpointInterval = batchDuration
-
- test("multiple failures with updateStateByKey") {
- val n = 30
- // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ...
- val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq
- // Last output: [ (a, 465) ] for n=30
- val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) )
-
- val operation = (st: DStream[String]) => {
- val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
- Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
- }
- st.map(x => (x, 1))
- .updateStateByKey[RichInt](updateFunc)
- .checkpoint(Seconds(2))
- .map(t => (t._1, t._2.self))
- }
-
- testOperationWithMultipleFailures(input, operation, lastOutput, n, n)
- }
-
- test("multiple failures with reduceByKeyAndWindow") {
- val n = 30
- val w = 100
- assert(w > n, "Window should be much larger than the number of input sets in this test")
- // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ...
- val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq
- // Last output: [ (a, 465) ]
- val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) )
-
- val operation = (st: DStream[String]) => {
- st.map(x => (x, 1))
- .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration)
- .checkpoint(Seconds(2))
- }
-
- testOperationWithMultipleFailures(input, operation, lastOutput, n, n)
- }
-
-
- /**
- * Tests stream operation with multiple master failures, and verifies whether the
- * final set of output values is as expected or not. Checking the final value is
- * proof that no intermediate data was lost due to master failures.
- */
- def testOperationWithMultipleFailures[U: ClassManifest, V: ClassManifest](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- lastExpectedOutput: Seq[V],
- numBatches: Int,
- numExpectedOutput: Int
- ) {
- var ssc = setupStreams[U, V](input, operation)
- val mergedOutput = new ArrayBuffer[Seq[V]]()
-
- var totalTimeRan = 0L
- while(totalTimeRan <= numBatches * batchDuration.milliseconds * 2) {
- new KillingThread(ssc, numBatches * batchDuration.milliseconds.toInt / 4).start()
- val (output, timeRan) = runStreamsWithRealClock[V](ssc, numBatches, numExpectedOutput)
-
- mergedOutput ++= output
- totalTimeRan += timeRan
- logInfo("New output = " + output)
- logInfo("Merged output = " + mergedOutput)
- logInfo("Total time spent = " + totalTimeRan)
- val sleepTime = Random.nextInt(numBatches * batchDuration.milliseconds.toInt / 8)
- logInfo(
- "\n-------------------------------------------\n" +
- " Restarting stream computation in " + sleepTime + " ms " +
- "\n-------------------------------------------\n"
- )
- Thread.sleep(sleepTime)
- FailureSuite.failed = false
- ssc = new StreamingContext(checkpointDir)
- }
- ssc.stop()
- ssc = null
-
- // Verify whether the last output is the expected one
- val lastOutput = mergedOutput(mergedOutput.lastIndexWhere(!_.isEmpty))
- assert(lastOutput.toSet === lastExpectedOutput.toSet)
- logInfo("Finished computation after " + FailureSuite.failureCount + " failures")
+ FileUtils.deleteDirectory(new File(directory))
}
- /**
- * Runs the streams set up in `ssc` on real clock until the expected max number of
- */
- def runStreamsWithRealClock[V: ClassManifest](
- ssc: StreamingContext,
- numBatches: Int,
- maxExpectedOutput: Int
- ): (Seq[Seq[V]], Long) = {
-
- System.clearProperty("spark.streaming.clock")
-
- assert(numBatches > 0, "Number of batches to run stream computation is zero")
- assert(maxExpectedOutput > 0, "Max expected outputs after " + numBatches + " is zero")
- logInfo("numBatches = " + numBatches + ", maxExpectedOutput = " + maxExpectedOutput)
-
- // Get the output buffer
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
- val output = outputStream.output
- val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong
- val startTime = System.currentTimeMillis()
-
- try {
- // Start computation
- ssc.start()
-
- // Wait until expected number of output items have been generated
- while (output.size < maxExpectedOutput && System.currentTimeMillis() - startTime < waitTime && !FailureSuite.failed) {
- logInfo("output.size = " + output.size + ", maxExpectedOutput = " + maxExpectedOutput)
- Thread.sleep(100)
- }
- } catch {
- case e: Exception => logInfo("Exception while running streams: " + e)
- } finally {
- ssc.stop()
- }
- val timeTaken = System.currentTimeMillis() - startTime
- logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms")
- (output, timeTaken)
+ test("multiple failures with map") {
+ MasterFailureTest.testMap(directory, numBatches, batchDuration)
}
-
-}
-
-object FailureSuite {
- var failed = false
- var failureCount = 0
-
- def reset() {
- failed = false
- failureCount = 0
+ test("multiple failures with updateStateByKey") {
+ MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration)
}
}
-class KillingThread(ssc: StreamingContext, maxKillWaitTime: Int) extends Thread with Logging {
- initLogging()
-
- override def run() {
- var minKillWaitTime = if (FailureSuite.failureCount == 0) 3000 else 1000 // to allow the first checkpoint
- val killWaitTime = minKillWaitTime + Random.nextInt(maxKillWaitTime)
- logInfo("Kill wait time = " + killWaitTime)
- Thread.sleep(killWaitTime.toLong)
- logInfo(
- "\n---------------------------------------\n" +
- "Killing streaming context after " + killWaitTime + " ms" +
- "\n---------------------------------------\n"
- )
- if (ssc != null) ssc.stop()
- FailureSuite.failed = true
- FailureSuite.failureCount += 1
- }
-}
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
index 70ae6e3934..1024d3ac97 100644
--- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -1,5 +1,11 @@
package spark.streaming
+import akka.actor.Actor
+import akka.actor.IO
+import akka.actor.IOManager
+import akka.actor.Props
+import akka.util.ByteString
+
import dstream.SparkFlumeEvent
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
@@ -7,6 +13,7 @@ import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock
import spark.storage.StorageLevel
+import spark.streaming.receivers.Receiver
import spark.Logging
import scala.util.Random
import org.apache.commons.io.FileUtils
@@ -19,40 +26,30 @@ import org.apache.avro.ipc.specific.SpecificRequestor
import java.nio.ByteBuffer
import collection.JavaConversions._
import java.nio.charset.Charset
+import com.google.common.io.Files
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
val testPort = 9999
- var testServer: TestServer = null
- var testDir: File = null
override def checkpointDir = "checkpoint"
after {
- FileUtils.deleteDirectory(new File(checkpointDir))
- if (testServer != null) {
- testServer.stop()
- testServer = null
- }
- if (testDir != null && testDir.exists()) {
- FileUtils.deleteDirectory(testDir)
- testDir = null
- }
-
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
}
- test("network input stream") {
+
+ test("socket input stream") {
// Start the server
- testServer = new TestServer(testPort)
+ val testServer = new TestServer(testPort)
testServer.start()
// Set up the streaming context and input streams
val ssc = new StreamingContext(master, framework, batchDuration)
- val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK)
+ val networkStream = ssc.socketTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
def output = outputBuffer.flatMap(x => x)
@@ -93,46 +90,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
}
}
- test("network input stream with checkpoint") {
- // Start the server
- testServer = new TestServer(testPort)
- testServer.start()
-
- // Set up the streaming context and input streams
- var ssc = new StreamingContext(master, framework, batchDuration)
- ssc.checkpoint(checkpointDir, checkpointInterval)
- val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK)
- var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]])
- ssc.registerOutputStream(outputStream)
- ssc.start()
-
- // Feed data to the server to send to the network receiver
- var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- for (i <- Seq(1, 2, 3)) {
- testServer.send(i.toString + "\n")
- Thread.sleep(100)
- clock.addToTime(batchDuration.milliseconds)
- }
- Thread.sleep(500)
- assert(outputStream.output.size > 0)
- ssc.stop()
-
- // Restart stream computation from checkpoint and feed more data to see whether
- // they are being received and processed
- logInfo("*********** RESTARTING ************")
- ssc = new StreamingContext(checkpointDir)
- ssc.start()
- clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- for (i <- Seq(4, 5, 6)) {
- testServer.send(i.toString + "\n")
- Thread.sleep(100)
- clock.addToTime(batchDuration.milliseconds)
- }
- Thread.sleep(500)
- outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]]
- assert(outputStream.output.size > 0)
- ssc.stop()
- }
test("flume input stream") {
// Set up the streaming context and input streams
@@ -146,7 +103,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq(1, 2, 3, 4, 5)
-
+ Thread.sleep(1000)
val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333));
val client = SpecificRequestor.getClient(
classOf[AvroSourceProtocol], transceiver);
@@ -182,42 +139,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
}
}
- test("file input stream") {
- // Create a temporary directory
- testDir = {
- var temp = File.createTempFile(".temp.", Random.nextInt().toString)
- temp.delete()
- temp.mkdirs()
- logInfo("Created temp dir " + temp)
- temp
- }
+ test("file input stream") {
+ // Disable manual clock as FileInputDStream does not work with manual clock
+ System.clearProperty("spark.streaming.clock")
// Set up the streaming context and input streams
+ val testDir = Files.createTempDir()
val ssc = new StreamingContext(master, framework, batchDuration)
- val filestream = ssc.textFileStream(testDir.toString)
+ val fileStream = ssc.textFileStream(testDir.toString)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
def output = outputBuffer.flatMap(x => x)
- val outputStream = new TestOutputStream(filestream, outputBuffer)
+ val outputStream = new TestOutputStream(fileStream, outputBuffer)
ssc.registerOutputStream(outputStream)
ssc.start()
// Create files in the temporary directory so that Spark Streaming can read data from it
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq(1, 2, 3, 4, 5)
val expectedOutput = input.map(_.toString)
Thread.sleep(1000)
for (i <- 0 until input.size) {
- FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n")
- Thread.sleep(500)
- clock.addToTime(batchDuration.milliseconds)
- //Thread.sleep(100)
+ val file = new File(testDir, i.toString)
+ FileUtils.writeStringToFile(file, input(i).toString + "\n")
+ logInfo("Created file " + file)
+ Thread.sleep(batchDuration.milliseconds)
+ Thread.sleep(1000)
}
val startTime = System.currentTimeMillis()
- /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
- logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size)
- Thread.sleep(100)
- }*/
Thread.sleep(1000)
val timeTaken = System.currentTimeMillis() - startTime
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
@@ -226,75 +174,76 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
// Verify whether data received by Spark Streaming was as expected
logInfo("--------------------------------")
- logInfo("output.size = " + outputBuffer.size)
- logInfo("output")
+ logInfo("output, size = " + outputBuffer.size)
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
- logInfo("expected output.size = " + expectedOutput.size)
- logInfo("expected output")
+ logInfo("expected output, size = " + expectedOutput.size)
expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")
// Verify whether all the elements received are as expected
// (whether the elements were received one in each interval is not verified)
- assert(output.size === expectedOutput.size)
- for (i <- 0 until output.size) {
- assert(output(i).size === 1)
- assert(output(i).head.toString === expectedOutput(i))
- }
+ assert(output.toList === expectedOutput.toList)
+
+ FileUtils.deleteDirectory(testDir)
+
+ // Enable manual clock back again for other tests
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
}
- test("file input stream with checkpoint") {
- // Create a temporary directory
- testDir = {
- var temp = File.createTempFile(".temp.", Random.nextInt().toString)
- temp.delete()
- temp.mkdirs()
- logInfo("Created temp dir " + temp)
- temp
- }
+
+ test("actor input stream") {
+ // Start the server
+ val port = testPort
+ val testServer = new TestServer(port)
+ testServer.start()
// Set up the streaming context and input streams
- var ssc = new StreamingContext(master, framework, batchDuration)
- ssc.checkpoint(checkpointDir, checkpointInterval)
- val filestream = ssc.textFileStream(testDir.toString)
- var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]])
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor",
+ StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ val outputStream = new TestOutputStream(networkStream, outputBuffer)
+ def output = outputBuffer.flatMap(x => x)
ssc.registerOutputStream(outputStream)
ssc.start()
- // Create files and advance manual clock to process them
- var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ // Feed data to the server to send to the network receiver
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = 1 to 9
+ val expectedOutput = input.map(x => x.toString)
Thread.sleep(1000)
- for (i <- Seq(1, 2, 3)) {
- FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
- Thread.sleep(100)
+ for (i <- 0 until input.size) {
+ testServer.send(input(i).toString)
+ Thread.sleep(500)
clock.addToTime(batchDuration.milliseconds)
}
- Thread.sleep(500)
- logInfo("Output = " + outputStream.output.mkString(","))
- assert(outputStream.output.size > 0)
+ Thread.sleep(1000)
+ logInfo("Stopping server")
+ testServer.stop()
+ logInfo("Stopping context")
ssc.stop()
- // Restart stream computation from checkpoint and create more files to see whether
- // they are being processed
- logInfo("*********** RESTARTING ************")
- ssc = new StreamingContext(checkpointDir)
- ssc.start()
- clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- Thread.sleep(500)
- for (i <- Seq(4, 5, 6)) {
- FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
- Thread.sleep(100)
- clock.addToTime(batchDuration.milliseconds)
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ // (whether the elements were received one in each interval is not verified)
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i) === expectedOutput(i))
}
- Thread.sleep(500)
- outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]]
- logInfo("Output = " + outputStream.output.mkString(","))
- assert(outputStream.output.size > 0)
- ssc.stop()
}
}
+/** This is server to test the network input stream */
class TestServer(port: Int) extends Logging {
val queue = new ArrayBlockingQueue[String](100)
@@ -353,3 +302,15 @@ object TestServer {
}
}
}
+
+class TestActor(port: Int) extends Actor with Receiver {
+
+ def bytesToString(byteString: ByteString) = byteString.utf8String
+
+ override def preStart = IOManager(context.system).connect(new InetSocketAddress(port))
+
+ def receive = {
+ case IO.Read(socket, bytes) =>
+ pushBlock(bytesToString(bytes))
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index 49129f3964..ad6aa79d10 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -28,6 +28,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
logInfo("Computing RDD for time " + validTime)
val index = ((validTime - zeroTime) / slideDuration - 1).toInt
val selectedInput = if (index < input.size) input(index) else Seq[T]()
+
+ // lets us test cases where RDDs are not created
+ if (selectedInput == null)
+ return None
+
val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
logInfo("Created RDD " + rdd.id + " with " + selectedInput)
Some(rdd)
@@ -58,20 +63,25 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
*/
trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
+ // Name of the framework for Spark context
def framework = "TestSuiteBase"
+ // Master for Spark context
def master = "local[2]"
+ // Batch duration
def batchDuration = Seconds(1)
+ // Directory where the checkpoint data will be saved
def checkpointDir = "checkpoint"
- def checkpointInterval = batchDuration
-
+ // Number of partitions of the input parallel collections created for testing
def numInputPartitions = 2
+ // Maximum time to wait before the test times out
def maxWaitTimeMillis = 10000
+ // Whether to actually wait in real time before changing manual clock
def actuallyWait = false
/**
@@ -86,7 +96,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework, batchDuration)
if (checkpointDir != null) {
- ssc.checkpoint(checkpointDir, checkpointInterval)
+ ssc.checkpoint(checkpointDir)
}
// Setup the stream computation
@@ -111,7 +121,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework, batchDuration)
if (checkpointDir != null) {
- ssc.checkpoint(checkpointDir, checkpointInterval)
+ ssc.checkpoint(checkpointDir)
}
// Setup the stream computation
@@ -135,9 +145,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
-
- System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
-
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
@@ -181,7 +188,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
} finally {
ssc.stop()
}
-
output
}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index cd9608df53..1b66f3bda2 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -5,6 +5,8 @@ import collection.mutable.ArrayBuffer
class WindowOperationsSuite extends TestSuiteBase {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
override def framework = "WindowOperationsSuite"
override def maxWaitTimeMillis = 20000
@@ -82,12 +84,9 @@ class WindowOperationsSuite extends TestSuiteBase {
)
/*
- The output of the reduceByKeyAndWindow with inverse reduce function is
- different from the naive reduceByKeyAndWindow. Even if the count of a
- particular key is 0, the key does not get eliminated from the RDDs of
- ReducedWindowedDStream. This causes the number of keys in these RDDs to
- increase forever. A more generalized version that allows elimination of
- keys should be considered.
+ The output of the reduceByKeyAndWindow with inverse function but without a filter
+ function will be different from the naive reduceByKeyAndWindow, as no keys get
+ eliminated from the ReducedWindowedDStream even if the value of a key becomes 0.
*/
val bigReduceInvOutput = Seq(
@@ -175,31 +174,31 @@ class WindowOperationsSuite extends TestSuiteBase {
// Testing reduceByKeyAndWindow (with invertible reduce function)
- testReduceByKeyAndWindowInv(
+ testReduceByKeyAndWindowWithInverse(
"basic reduction",
Seq(Seq(("a", 1), ("a", 3)) ),
Seq(Seq(("a", 4)) )
)
- testReduceByKeyAndWindowInv(
+ testReduceByKeyAndWindowWithInverse(
"key already in window and new value added into window",
Seq( Seq(("a", 1)), Seq(("a", 1)) ),
Seq( Seq(("a", 1)), Seq(("a", 2)) )
)
- testReduceByKeyAndWindowInv(
+ testReduceByKeyAndWindowWithInverse(
"new key added into window",
Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
)
- testReduceByKeyAndWindowInv(
+ testReduceByKeyAndWindowWithInverse(
"key removed from window",
Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) )
)
- testReduceByKeyAndWindowInv(
+ testReduceByKeyAndWindowWithInverse(
"larger slide time",
largerSlideInput,
largerSlideReduceOutput,
@@ -207,7 +206,9 @@ class WindowOperationsSuite extends TestSuiteBase {
Seconds(2)
)
- testReduceByKeyAndWindowInv("big test", bigInput, bigReduceInvOutput)
+ testReduceByKeyAndWindowWithInverse("big test", bigInput, bigReduceInvOutput)
+
+ testReduceByKeyAndWindowWithFilteredInverse("big test", bigInput, bigReduceOutput)
test("groupByKeyAndWindow") {
val input = bigInput
@@ -235,14 +236,14 @@ class WindowOperationsSuite extends TestSuiteBase {
testOperation(input, operation, expectedOutput, numBatches, true)
}
- test("countByKeyAndWindow") {
- val input = Seq(Seq(("a", 1)), Seq(("b", 1), ("b", 2)), Seq(("a", 10), ("b", 20)))
+ test("countByValueAndWindow") {
+ val input = Seq(Seq("a"), Seq("b", "b"), Seq("a", "b"))
val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3)))
val windowDuration = Seconds(2)
val slideDuration = Seconds(1)
val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
- val operation = (s: DStream[(String, Int)]) => {
- s.countByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt))
+ val operation = (s: DStream[String]) => {
+ s.countByValueAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt))
}
testOperation(input, operation, expectedOutput, numBatches, true)
}
@@ -272,29 +273,50 @@ class WindowOperationsSuite extends TestSuiteBase {
slideDuration: Duration = Seconds(1)
) {
test("reduceByKeyAndWindow - " + name) {
+ logInfo("reduceByKeyAndWindow - " + name)
val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
- s.reduceByKeyAndWindow(_ + _, windowDuration, slideDuration).persist()
+ s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration)
}
testOperation(input, operation, expectedOutput, numBatches, true)
}
}
- def testReduceByKeyAndWindowInv(
+ def testReduceByKeyAndWindowWithInverse(
name: String,
input: Seq[Seq[(String, Int)]],
expectedOutput: Seq[Seq[(String, Int)]],
windowDuration: Duration = Seconds(2),
slideDuration: Duration = Seconds(1)
) {
- test("reduceByKeyAndWindowInv - " + name) {
+ test("reduceByKeyAndWindow with inverse function - " + name) {
+ logInfo("reduceByKeyAndWindow with inverse function - " + name)
val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration)
- .persist()
.checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
}
testOperation(input, operation, expectedOutput, numBatches, true)
}
}
+
+ def testReduceByKeyAndWindowWithFilteredInverse(
+ name: String,
+ input: Seq[Seq[(String, Int)]],
+ expectedOutput: Seq[Seq[(String, Int)]],
+ windowDuration: Duration = Seconds(2),
+ slideDuration: Duration = Seconds(1)
+ ) {
+ test("reduceByKeyAndWindow with inverse and filter functions - " + name) {
+ logInfo("reduceByKeyAndWindow with inverse and filter functions - " + name)
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val filterFunc = (p: (String, Int)) => p._2 != 0
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration, filterFunc = filterFunc)
+ .persist()
+ .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
}