aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--bagel/pom.xml11
-rw-r--r--bagel/src/test/resources/log4j.properties4
-rw-r--r--core/pom.xml17
-rw-r--r--core/src/main/scala/spark/Accumulators.scala41
-rw-r--r--core/src/main/scala/spark/BoundedMemoryCache.scala118
-rw-r--r--core/src/main/scala/spark/CacheManager.scala65
-rw-r--r--core/src/main/scala/spark/CacheTracker.scala238
-rw-r--r--core/src/main/scala/spark/DaemonThreadFactory.scala18
-rw-r--r--core/src/main/scala/spark/HttpFileServer.scala8
-rw-r--r--core/src/main/scala/spark/HttpServer.scala9
-rw-r--r--core/src/main/scala/spark/KryoSerializer.scala210
-rw-r--r--core/src/main/scala/spark/Logging.scala3
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala60
-rw-r--r--core/src/main/scala/spark/PairRDDFunctions.scala86
-rw-r--r--core/src/main/scala/spark/ParallelCollection.scala24
-rw-r--r--core/src/main/scala/spark/Partitioner.scala4
-rw-r--r--core/src/main/scala/spark/RDD.scala197
-rw-r--r--core/src/main/scala/spark/RDDCheckpointData.scala105
-rw-r--r--core/src/main/scala/spark/SequenceFileRDDFunctions.scala8
-rw-r--r--core/src/main/scala/spark/SizeEstimator.scala13
-rw-r--r--core/src/main/scala/spark/SparkContext.scala157
-rw-r--r--core/src/main/scala/spark/SparkEnv.scala39
-rw-r--r--core/src/main/scala/spark/SparkFiles.java25
-rw-r--r--core/src/main/scala/spark/TaskContext.scala3
-rw-r--r--core/src/main/scala/spark/Utils.scala125
-rw-r--r--core/src/main/scala/spark/api/java/JavaPairRDD.scala10
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala33
-rw-r--r--core/src/main/scala/spark/api/java/JavaSparkContext.scala105
-rw-r--r--core/src/main/scala/spark/api/java/StorageLevels.java11
-rw-r--r--core/src/main/scala/spark/api/python/PythonPartitioner.scala48
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala293
-rw-r--r--core/src/main/scala/spark/broadcast/Broadcast.scala2
-rw-r--r--core/src/main/scala/spark/broadcast/HttpBroadcast.scala26
-rw-r--r--core/src/main/scala/spark/deploy/DeployMessage.scala4
-rw-r--r--core/src/main/scala/spark/deploy/JobDescription.scala3
-rw-r--r--core/src/main/scala/spark/deploy/JsonProtocol.scala78
-rw-r--r--core/src/main/scala/spark/deploy/client/TestClient.scala2
-rw-r--r--core/src/main/scala/spark/deploy/master/Master.scala15
-rw-r--r--core/src/main/scala/spark/deploy/master/MasterWebUI.scala58
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerInfo.scala6
-rw-r--r--core/src/main/scala/spark/deploy/master/WorkerState.scala7
-rw-r--r--core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala5
-rw-r--r--core/src/main/scala/spark/deploy/worker/Worker.scala4
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerArguments.scala22
-rw-r--r--core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala19
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala34
-rw-r--r--core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala3
-rw-r--r--core/src/main/scala/spark/network/Connection.scala7
-rw-r--r--core/src/main/scala/spark/network/ConnectionManager.scala17
-rw-r--r--core/src/main/scala/spark/network/ConnectionManagerTest.scala24
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala20
-rw-r--r--core/src/main/scala/spark/rdd/CartesianRDD.scala47
-rw-r--r--core/src/main/scala/spark/rdd/CheckpointRDD.scala128
-rw-r--r--core/src/main/scala/spark/rdd/CoGroupedRDD.scala70
-rw-r--r--core/src/main/scala/spark/rdd/CoalescedRDD.scala47
-rw-r--r--core/src/main/scala/spark/rdd/FilteredRDD.scala17
-rw-r--r--core/src/main/scala/spark/rdd/FlatMappedRDD.scala10
-rw-r--r--core/src/main/scala/spark/rdd/GlommedRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/HadoopRDD.scala15
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala14
-rw-r--r--core/src/main/scala/spark/rdd/MappedRDD.scala11
-rw-r--r--core/src/main/scala/spark/rdd/NewHadoopRDD.scala19
-rw-r--r--core/src/main/scala/spark/rdd/PipedRDD.scala18
-rw-r--r--core/src/main/scala/spark/rdd/SampledRDD.scala29
-rw-r--r--core/src/main/scala/spark/rdd/ShuffledRDD.scala28
-rw-r--r--core/src/main/scala/spark/rdd/UnionRDD.scala45
-rw-r--r--core/src/main/scala/spark/rdd/ZippedRDD.scala60
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala90
-rw-r--r--core/src/main/scala/spark/scheduler/MapStatus.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala102
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala24
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala31
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala6
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala44
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala16
-rw-r--r--core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala10
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala231
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerId.scala70
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMaster.scala727
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMasterActor.scala401
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerMessages.scala100
-rw-r--r--core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala16
-rw-r--r--core/src/main/scala/spark/storage/BlockMessage.scala2
-rw-r--r--core/src/main/scala/spark/storage/BlockStore.scala7
-rw-r--r--core/src/main/scala/spark/storage/DiskStore.scala5
-rw-r--r--core/src/main/scala/spark/storage/MemoryStore.scala6
-rw-r--r--core/src/main/scala/spark/storage/StorageLevel.scala80
-rw-r--r--core/src/main/scala/spark/storage/ThreadingTest.scala13
-rw-r--r--core/src/main/scala/spark/util/AkkaUtils.scala1
-rw-r--r--core/src/main/scala/spark/util/IdGenerator.scala14
-rw-r--r--core/src/main/scala/spark/util/MetadataCleaner.scala44
-rw-r--r--core/src/main/scala/spark/util/RateLimitedOutputStream.scala62
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashMap.scala93
-rw-r--r--core/src/main/scala/spark/util/TimeStampedHashSet.scala69
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_row.scala.html1
-rw-r--r--core/src/main/twirl/spark/deploy/master/worker_table.scala.html1
-rw-r--r--core/src/test/resources/log4j.properties4
-rw-r--r--core/src/test/scala/spark/BoundedMemoryCacheSuite.scala58
-rw-r--r--core/src/test/scala/spark/CacheTrackerSuite.scala131
-rw-r--r--core/src/test/scala/spark/CheckpointSuite.scala357
-rw-r--r--core/src/test/scala/spark/ClosureCleanerSuite.scala2
-rw-r--r--core/src/test/scala/spark/DistributedSuite.scala69
-rw-r--r--core/src/test/scala/spark/DriverSuite.scala31
-rw-r--r--core/src/test/scala/spark/FileServerSuite.scala13
-rw-r--r--core/src/test/scala/spark/JavaAPISuite.java98
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala56
-rw-r--r--core/src/test/scala/spark/PartitioningSuite.scala26
-rw-r--r--core/src/test/scala/spark/RDDSuite.scala60
-rw-r--r--core/src/test/scala/spark/ShuffleSuite.scala7
-rw-r--r--core/src/test/scala/spark/SizeEstimatorSuite.scala48
-rw-r--r--core/src/test/scala/spark/scheduler/TaskContextSuite.scala42
-rw-r--r--core/src/test/scala/spark/storage/BlockManagerSuite.scala168
-rw-r--r--core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala23
-rw-r--r--docs/README.md8
-rwxr-xr-xdocs/_layouts/global.html11
-rw-r--r--docs/_plugins/copy_api_dirs.rb21
-rw-r--r--docs/api.md6
-rw-r--r--docs/configuration.md27
-rw-r--r--docs/ec2-scripts.md4
-rw-r--r--docs/index.md17
-rw-r--r--docs/java-programming-guide.md3
-rw-r--r--docs/python-programming-guide.md110
-rw-r--r--docs/quick-start.md50
-rw-r--r--docs/scala-programming-guide.md3
-rw-r--r--docs/spark-standalone.md43
-rw-r--r--docs/streaming-programming-guide.md313
-rw-r--r--docs/tuning.md30
-rw-r--r--examples/pom.xml28
-rw-r--r--examples/src/main/scala/spark/examples/LocalLR.scala2
-rw-r--r--examples/src/main/scala/spark/examples/SparkALS.scala59
-rw-r--r--examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala43
-rw-r--r--examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala36
-rw-r--r--examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java50
-rw-r--r--examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java62
-rw-r--r--examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java62
-rw-r--r--examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala69
-rw-r--r--examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala36
-rw-r--r--examples/src/main/scala/spark/streaming/examples/QueueStream.scala39
-rw-r--r--examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala46
-rw-r--r--examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala85
-rw-r--r--examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala84
-rw-r--r--examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala60
-rw-r--r--examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala71
-rw-r--r--pom.xml45
-rw-r--r--project/SparkBuild.scala32
-rwxr-xr-xpyspark39
-rw-r--r--python/.gitignore2
-rw-r--r--python/epydoc.conf19
-rwxr-xr-xpython/examples/als.py71
-rw-r--r--python/examples/kmeans.py54
-rwxr-xr-xpython/examples/logistic_regression.py57
-rw-r--r--python/examples/pi.py21
-rw-r--r--python/examples/transitive_closure.py50
-rw-r--r--python/examples/wordcount.py19
-rw-r--r--python/lib/PY4J_LICENSE.txt27
-rw-r--r--python/lib/PY4J_VERSION.txt1
-rw-r--r--python/lib/py4j0.7.eggbin0 -> 191756 bytes
-rw-r--r--python/lib/py4j0.7.jarbin0 -> 103286 bytes
-rw-r--r--python/pyspark/__init__.py27
-rw-r--r--python/pyspark/accumulators.py187
-rw-r--r--python/pyspark/broadcast.py48
-rw-r--r--python/pyspark/cloudpickle.py974
-rw-r--r--python/pyspark/context.py258
-rw-r--r--python/pyspark/files.py38
-rw-r--r--python/pyspark/java_gateway.py38
-rw-r--r--python/pyspark/join.py92
-rw-r--r--python/pyspark/rdd.py761
-rw-r--r--python/pyspark/serializers.py83
-rw-r--r--python/pyspark/shell.py17
-rw-r--r--python/pyspark/tests.py112
-rw-r--r--python/pyspark/worker.py52
-rwxr-xr-xpython/run-tests35
-rwxr-xr-xpython/test_support/hello.txt1
-rwxr-xr-xpython/test_support/userlibrary.py7
-rw-r--r--repl-bin/pom.xml16
-rw-r--r--repl-bin/src/deb/control/control4
-rw-r--r--repl/pom.xml35
-rw-r--r--repl/src/test/resources/log4j.properties4
-rwxr-xr-xrun27
-rw-r--r--run2.cmd4
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jarbin0 -> 1358063 bytes
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom9
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml12
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha11
-rw-r--r--streaming/pom.xml155
-rw-r--r--streaming/src/main/scala/spark/streaming/Checkpoint.scala118
-rw-r--r--streaming/src/main/scala/spark/streaming/DStream.scala657
-rw-r--r--streaming/src/main/scala/spark/streaming/DStreamGraph.scala134
-rw-r--r--streaming/src/main/scala/spark/streaming/Duration.scala62
-rw-r--r--streaming/src/main/scala/spark/streaming/Interval.scala41
-rw-r--r--streaming/src/main/scala/spark/streaming/Job.scala24
-rw-r--r--streaming/src/main/scala/spark/streaming/JobManager.scala33
-rw-r--r--streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala151
-rw-r--r--streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala562
-rw-r--r--streaming/src/main/scala/spark/streaming/Scheduler.scala77
-rw-r--r--streaming/src/main/scala/spark/streaming/StreamingContext.scala411
-rw-r--r--streaming/src/main/scala/spark/streaming/Time.scala42
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala91
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala183
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala638
-rw-r--r--streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala346
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala40
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala19
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala102
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala21
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala20
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala20
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala137
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala28
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala17
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala19
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala200
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala21
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala21
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala20
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala254
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala41
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala91
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala149
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala27
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala103
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala84
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala19
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala40
-rw-r--r--streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala40
-rw-r--r--streaming/src/main/scala/spark/streaming/util/Clock.scala84
-rw-r--r--streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala98
-rw-r--r--streaming/src/main/scala/spark/streaming/util/RawTextSender.scala60
-rw-r--r--streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala75
-rw-r--r--streaming/src/test/java/spark/streaming/JavaAPISuite.java1029
-rw-r--r--streaming/src/test/java/spark/streaming/JavaTestUtils.scala65
-rw-r--r--streaming/src/test/resources/log4j.properties11
-rw-r--r--streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala218
-rw-r--r--streaming/src/test/scala/spark/streaming/CheckpointSuite.scala210
-rw-r--r--streaming/src/test/scala/spark/streaming/FailureSuite.scala191
-rw-r--r--streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala355
-rw-r--r--streaming/src/test/scala/spark/streaming/TestSuiteBase.scala291
-rw-r--r--streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala300
246 files changed, 17206 insertions, 2297 deletions
diff --git a/.gitignore b/.gitignore
index 672c60af3d..155e785b01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ third_party/libmesos.so
third_party/libmesos.dylib
conf/java-opts
conf/spark-env.sh
+conf/streaming-env.sh
conf/log4j.properties
docs/_site
docs/api
@@ -31,6 +32,7 @@ project/plugins/src_managed/
logs/
log/
spark-tests.log
+streaming-tests.log
dependency-reduced-pom.xml
.ensime
.ensime_lucene
diff --git a/bagel/pom.xml b/bagel/pom.xml
index a8256a6e8b..5f58347204 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -45,6 +45,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -72,6 +77,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 4c99e450bc..83d05cab2f 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file bagel/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=bagel/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/core/pom.xml b/core/pom.xml
index ae52c20657..862d3ec37a 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -72,6 +72,10 @@
<artifactId>spray-server</artifactId>
</dependency>
<dependency>
+ <groupId>cc.spray</groupId>
+ <artifactId>spray-json_${scala.version}</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId>
</dependency>
@@ -159,6 +163,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -211,6 +220,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
@@ -267,4 +282,4 @@
</build>
</profile>
</profiles>
-</project> \ No newline at end of file
+</project>
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bacd0ace37..57c6df35be 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -25,8 +25,7 @@ class Accumulable[R, T] (
extends Serializable {
val id = Accumulators.newId
- @transient
- private var value_ = initialValue // Current value on master
+ @transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
@@ -39,19 +38,36 @@ class Accumulable[R, T] (
def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
* Merge two accumulable objects together
- *
+ *
* Normally, a user will not want to use this version, but will instead call `+=`.
- * @param term the other Accumulable that will get merged with this
+ * @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `add`.
+ * @param term the other `R` that will get merged with this
+ */
+ def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
* Access the accumulator's current value; only allowed on master.
*/
- def value = {
- if (!deserialized) value_
- else throw new UnsupportedOperationException("Can't read accumulator value in task")
+ def value: R = {
+ if (!deserialized) {
+ value_
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
}
/**
@@ -68,10 +84,17 @@ class Accumulable[R, T] (
/**
* Set the accumulator's value; only allowed on master.
*/
- def value_= (r: R) {
- if (!deserialized) value_ = r
+ def value_= (newValue: R) {
+ if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
+
+ /**
+ * Set the accumulator's value; only allowed on master
+ */
+ def setValue(newValue: R) {
+ this.value = newValue
+ }
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
deleted file mode 100644
index e8392a194f..0000000000
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ /dev/null
@@ -1,118 +0,0 @@
-package spark
-
-import java.util.LinkedHashMap
-
-/**
- * An implementation of Cache that estimates the sizes of its entries and attempts to limit its
- * total memory usage to a fraction of the JVM heap. Objects' sizes are estimated using
- * SizeEstimator, which has limitations; most notably, we will overestimate total memory used if
- * some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
- * when most of the space is used by arrays of primitives or of simple classes.
- */
-private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
- logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
-
- def this() {
- this(BoundedMemoryCache.getMaxBytes)
- }
-
- private var currentBytes = 0L
- private val map = new LinkedHashMap[(Any, Int), Entry](32, 0.75f, true)
-
- override def get(datasetId: Any, partition: Int): Any = {
- synchronized {
- val entry = map.get((datasetId, partition))
- if (entry != null) {
- entry.value
- } else {
- null
- }
- }
- }
-
- override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = {
- val key = (datasetId, partition)
- logInfo("Asked to add key " + key)
- val size = estimateValueSize(key, value)
- synchronized {
- if (size > getCapacity) {
- return CachePutFailure()
- } else if (ensureFreeSpace(datasetId, size)) {
- logInfo("Adding key " + key)
- map.put(key, new Entry(value, size))
- currentBytes += size
- logInfo("Number of entries is now " + map.size)
- return CachePutSuccess(size)
- } else {
- logInfo("Didn't add key " + key + " because we would have evicted part of same dataset")
- return CachePutFailure()
- }
- }
- }
-
- override def getCapacity: Long = maxBytes
-
- /**
- * Estimate sizeOf 'value'
- */
- private def estimateValueSize(key: (Any, Int), value: Any) = {
- val startTime = System.currentTimeMillis
- val size = SizeEstimator.estimate(value.asInstanceOf[AnyRef])
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Estimated size for key %s is %d".format(key, size))
- logInfo("Size estimation for key %s took %d ms".format(key, timeTaken))
- size
- }
-
- /**
- * Remove least recently used entries from the map until at least space bytes are free, in order
- * to make space for a partition from the given dataset ID. If this cannot be done without
- * evicting other data from the same dataset, returns false; otherwise, returns true. Assumes
- * that a lock is held on the BoundedMemoryCache.
- */
- private def ensureFreeSpace(datasetId: Any, space: Long): Boolean = {
- logInfo("ensureFreeSpace(%s, %d) called with curBytes=%d, maxBytes=%d".format(
- datasetId, space, currentBytes, maxBytes))
- val iter = map.entrySet.iterator // Will give entries in LRU order
- while (maxBytes - currentBytes < space && iter.hasNext) {
- val mapEntry = iter.next()
- val (entryDatasetId, entryPartition) = mapEntry.getKey
- if (entryDatasetId == datasetId) {
- // Cannot make space without removing part of the same dataset, or a more recently used one
- return false
- }
- reportEntryDropped(entryDatasetId, entryPartition, mapEntry.getValue)
- currentBytes -= mapEntry.getValue.size
- iter.remove()
- }
- return true
- }
-
- protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- // TODO: remove BoundedMemoryCache
-
- val (keySpaceId, innerDatasetId) = datasetId.asInstanceOf[(Any, Any)]
- innerDatasetId match {
- case rddId: Int =>
- SparkEnv.get.cacheTracker.dropEntry(rddId, partition)
- case broadcastUUID: java.util.UUID =>
- // TODO: Maybe something should be done if the broadcasted variable falls out of cache
- case _ =>
- }
- }
-}
-
-// An entry in our map; stores a cached object and its size in bytes
-private[spark] case class Entry(value: Any, size: Long)
-
-private[spark] object BoundedMemoryCache {
- /**
- * Get maximum cache capacity from system configuration
- */
- def getMaxBytes: Long = {
- val memoryFractionToUse = System.getProperty("spark.boundedMemoryCache.memoryFraction", "0.66").toDouble
- (Runtime.getRuntime.maxMemory * memoryFractionToUse).toLong
- }
-}
-
diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala
new file mode 100644
index 0000000000..a0b53fd9d6
--- /dev/null
+++ b/core/src/main/scala/spark/CacheManager.scala
@@ -0,0 +1,65 @@
+package spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing RDDs split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Split is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this would fail
+ // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+ // partition but we didn't want to make space for it. However, that case is unlikely
+ // because it's unlikely that two threads would work on the same RDD partition. One
+ // downside of the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.compute(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala
deleted file mode 100644
index 3d79078733..0000000000
--- a/core/src/main/scala/spark/CacheTracker.scala
+++ /dev/null
@@ -1,238 +0,0 @@
-package spark
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-import spark.storage.BlockManager
-import spark.storage.StorageLevel
-
-private[spark] sealed trait CacheTrackerMessage
-
-private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
- extends CacheTrackerMessage
-private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
-private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
-private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
-private[spark] case object GetCacheStatus extends CacheTrackerMessage
-private[spark] case object GetCacheLocations extends CacheTrackerMessage
-private[spark] case object StopCacheTracker extends CacheTrackerMessage
-
-private[spark] class CacheTrackerActor extends Actor with Logging {
- // TODO: Should probably store (String, CacheType) tuples
- private val locs = new HashMap[Int, Array[List[String]]]
-
- /**
- * A map from the slave's host name to its cache size.
- */
- private val slaveCapacity = new HashMap[String, Long]
- private val slaveUsage = new HashMap[String, Long]
-
- private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L)
- private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L)
- private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host)
-
- def receive = {
- case SlaveCacheStarted(host: String, size: Long) =>
- slaveCapacity.put(host, size)
- slaveUsage.put(host, 0)
- sender ! true
-
- case RegisterRDD(rddId: Int, numPartitions: Int) =>
- logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions")
- locs(rddId) = Array.fill[List[String]](numPartitions)(Nil)
- sender ! true
-
- case AddedToCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) + size)
- locs(rddId)(partition) = host :: locs(rddId)(partition)
- sender ! true
-
- case DroppedFromCache(rddId, partition, host, size) =>
- slaveUsage.put(host, getCacheUsage(host) - size)
- // Do a sanity check to make sure usage is greater than 0.
- locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
- sender ! true
-
- case MemoryCacheLost(host) =>
- logInfo("Memory cache lost on " + host)
- for ((id, locations) <- locs) {
- for (i <- 0 until locations.length) {
- locations(i) = locations(i).filterNot(_ == host)
- }
- }
- sender ! true
-
- case GetCacheLocations =>
- logInfo("Asked for current cache locations")
- sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())}
-
- case GetCacheStatus =>
- val status = slaveCapacity.map { case (host, capacity) =>
- (host, capacity, getCacheUsage(host))
- }.toSeq
- sender ! status
-
- case StopCacheTracker =>
- logInfo("Stopping CacheTrackerActor")
- sender ! true
- context.stop(self)
- }
-}
-
-private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
- extends Logging {
-
- // Tracker actor on the master, or remote reference to it on workers
- val ip: String = System.getProperty("spark.master.host", "localhost")
- val port: Int = System.getProperty("spark.master.port", "7077").toInt
- val actorName: String = "CacheTracker"
-
- val timeout = 10.seconds
-
- var trackerActor: ActorRef = if (isMaster) {
- val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
- logInfo("Registered CacheTrackerActor actor")
- actor
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
- actorSystem.actorFor(url)
- }
-
- val registeredRddIds = new HashSet[Int]
-
- // Remembers which splits are currently being loaded (on worker nodes)
- val loading = new HashSet[String]
-
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
- try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with CacheTracker", e)
- }
- }
-
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askTracker(message) != true) {
- throw new SparkException("Error reply received from CacheTracker")
- }
- }
-
- // Registers an RDD (on master only)
- def registerRDD(rddId: Int, numPartitions: Int) {
- registeredRddIds.synchronized {
- if (!registeredRddIds.contains(rddId)) {
- logInfo("Registering RDD ID " + rddId + " with cache")
- registeredRddIds += rddId
- communicate(RegisterRDD(rddId, numPartitions))
- }
- }
- }
-
- // For BlockManager.scala only
- def cacheLost(host: String) {
- communicate(MemoryCacheLost(host))
- logInfo("CacheTracker successfully removed entries on " + host)
- }
-
- // Get the usage status of slave caches. Each tuple in the returned sequence
- // is in the form of (host name, capacity, usage).
- def getCacheStatus(): Seq[(String, Long, Long)] = {
- askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]]
- }
-
- // For BlockManager.scala only
- def notifyFromBlockManager(t: AddedToCache) {
- communicate(t)
- }
-
- // Get a snapshot of the currently known locations
- def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = {
- askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- }
-
- // Gets or computes an RDD split
- def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
- : Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
- logInfo("Cache key is " + key)
- blockManager.get(key) match {
- case Some(cachedValues) =>
- // Split is in cache, so just return its values
- logInfo("Found partition in cache!")
- return cachedValues.asInstanceOf[Iterator[T]]
-
- case None =>
- // Mark the split as loading (unless someone else marks it first)
- loading.synchronized {
- if (loading.contains(key)) {
- logInfo("Loading contains " + key + ", waiting...")
- while (loading.contains(key)) {
- try {loading.wait()} catch {case _ =>}
- }
- logInfo("Loading no longer contains " + key + ", so returning cached result")
- // See whether someone else has successfully loaded it. The main way this would fail
- // is for the RDD-level cache eviction policy if someone else has loaded the same RDD
- // partition but we didn't want to make space for it. However, that case is unlikely
- // because it's unlikely that two threads would work on the same RDD partition. One
- // downside of the current code is that threads wait serially if this does happen.
- blockManager.get(key) match {
- case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
- case None =>
- logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
- loading.add(key)
- }
- } else {
- loading.add(key)
- }
- }
- // If we got here, we have to load the split
- // Tell the master that we're doing so
- //val host = System.getProperty("spark.hostname", Utils.localHostName)
- //val future = trackerActor !! AddedToCache(rdd.id, split.index, host)
- // TODO: fetch any remote copy of the split that may be available
- // TODO: also register a listener for when it unloads
- logInfo("Computing partition " + split)
- val elements = new ArrayBuffer[Any]
- elements ++= rdd.compute(split, context)
- try {
- // Try to put this block in the blockManager
- blockManager.put(key, elements, storageLevel, true)
- //future.apply() // Wait for the reply from the cache tracker
- } finally {
- loading.synchronized {
- loading.remove(key)
- loading.notifyAll()
- }
- }
- return elements.iterator.asInstanceOf[Iterator[T]]
- }
- }
-
- // Called by the Cache to report that an entry has been dropped from it
- def dropEntry(rddId: Int, partition: Int) {
- communicate(DroppedFromCache(rddId, partition, Utils.localHostName()))
- }
-
- def stop() {
- communicate(StopCacheTracker)
- registeredRddIds.clear()
- trackerActor = null
- }
-}
diff --git a/core/src/main/scala/spark/DaemonThreadFactory.scala b/core/src/main/scala/spark/DaemonThreadFactory.scala
deleted file mode 100644
index 56e59adeb7..0000000000
--- a/core/src/main/scala/spark/DaemonThreadFactory.scala
+++ /dev/null
@@ -1,18 +0,0 @@
-package spark
-
-import java.util.concurrent.ThreadFactory
-
-/**
- * A ThreadFactory that creates daemon threads
- */
-private object DaemonThreadFactory extends ThreadFactory {
- override def newThread(r: Runnable): Thread = new DaemonThread(r)
-}
-
-private class DaemonThread(r: Runnable = null) extends Thread {
- override def run() {
- if (r != null) {
- r.run()
- }
- }
-} \ No newline at end of file
diff --git a/core/src/main/scala/spark/HttpFileServer.scala b/core/src/main/scala/spark/HttpFileServer.scala
index 659d17718f..00901d95e2 100644
--- a/core/src/main/scala/spark/HttpFileServer.scala
+++ b/core/src/main/scala/spark/HttpFileServer.scala
@@ -1,9 +1,7 @@
package spark
-import java.io.{File, PrintWriter}
-import java.net.URL
-import scala.collection.mutable.HashMap
-import org.apache.hadoop.fs.FileUtil
+import java.io.{File}
+import com.google.common.io.Files
private[spark] class HttpFileServer extends Logging {
@@ -40,7 +38,7 @@ private[spark] class HttpFileServer extends Logging {
}
def addFileToDir(file: File, dir: File) : String = {
- Utils.copyFile(file, new File(dir, file.getName))
+ Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
diff --git a/core/src/main/scala/spark/HttpServer.scala b/core/src/main/scala/spark/HttpServer.scala
index 0196595ba1..4e0507c080 100644
--- a/core/src/main/scala/spark/HttpServer.scala
+++ b/core/src/main/scala/spark/HttpServer.scala
@@ -4,6 +4,7 @@ import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
@@ -27,7 +28,13 @@ private[spark] class HttpServer(resourceBase: File) extends Logging {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
- server = new Server(0)
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala
index 44b630e478..0bd73e936b 100644
--- a/core/src/main/scala/spark/KryoSerializer.scala
+++ b/core/src/main/scala/spark/KryoSerializer.scala
@@ -9,153 +9,80 @@ import scala.collection.mutable
import com.esotericsoftware.kryo._
import com.esotericsoftware.kryo.{Serializer => KSerializer}
-import com.esotericsoftware.kryo.serialize.ClassSerializer
-import com.esotericsoftware.kryo.serialize.SerializableSerializer
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
-/**
- * Zig-zag encoder used to write object sizes to serialization streams.
- * Based on Kryo's integer encoder.
- */
-private[spark] object ZigZag {
- def writeInt(n: Int, out: OutputStream) {
- var value = n
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- if ((value & ~0x7F) == 0) {
- out.write(value)
- return
- }
- out.write(((value & 0x7F) | 0x80))
- value >>>= 7
- out.write(value)
- }
+private[spark]
+class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
- def readInt(in: InputStream): Int = {
- var offset = 0
- var result = 0
- while (offset < 32) {
- val b = in.read()
- if (b == -1) {
- throw new EOFException("End of stream")
- }
- result |= ((b & 0x7F) << offset)
- if ((b & 0x80) == 0) {
- return result
- }
- offset += 7
- }
- throw new SparkException("Malformed zigzag-encoded integer")
- }
-}
-
-private[spark]
-class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
-extends SerializationStream {
- val channel = Channels.newChannel(out)
+ val output = new KryoOutput(outStream)
def writeObject[T](t: T): SerializationStream = {
- kryo.writeClassAndObject(threadBuffer, t)
- ZigZag.writeInt(threadBuffer.position(), out)
- threadBuffer.flip()
- channel.write(threadBuffer)
- threadBuffer.clear()
+ kryo.writeClassAndObject(output, t)
this
}
- def flush() { out.flush() }
- def close() { out.close() }
+ def flush() { output.flush() }
+ def close() { output.close() }
}
-private[spark]
-class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
-extends DeserializationStream {
+private[spark]
+class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
+
+ val input = new KryoInput(inStream)
+
def readObject[T](): T = {
- val len = ZigZag.readInt(in)
- objectBuffer.readClassAndObject(in, len).asInstanceOf[T]
+ try { kryo.readClassAndObject(input).asInstanceOf[T] } catch {
+ // EOFException is the stopping condition; only Kryo's end-of-input underflow maps to it.
+ case e: com.esotericsoftware.kryo.KryoException
+ if Option(e.getMessage).exists(_.toLowerCase.contains("buffer underflow")) =>
+ throw new java.io.EOFException
+ }
}
- def close() { in.close() }
+ def close() {
+ // Kryo's Input automatically closes the input stream it is using.
+ input.close()
+ }
}
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
- val kryo = ks.kryo
- val threadBuffer = ks.threadBuffer.get()
- val objectBuffer = ks.objectBuffer.get()
+
+ val kryo = ks.kryo.get()
+ val output = ks.output.get()
+ val input = ks.input.get()
def serialize[T](t: T): ByteBuffer = {
- // Write it to our thread-local scratch buffer first to figure out the size, then return a new
- // ByteBuffer of the appropriate size
- threadBuffer.clear()
- kryo.writeClassAndObject(threadBuffer, t)
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
+ output.clear()
+ kryo.writeClassAndObject(output, t)
+ ByteBuffer.wrap(output.toBytes)
}
def deserialize[T](bytes: ByteBuffer): T = {
- kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
+ kryo.readClassAndObject(input).asInstanceOf[T] // reads only the [position, limit) window
}
def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
- val obj = kryo.readClassAndObject(bytes).asInstanceOf[T]
+ input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
+ val obj = kryo.readClassAndObject(input).asInstanceOf[T] // reads only [position, limit)
kryo.setClassLoader(oldClassLoader)
obj
}
def serializeStream(s: OutputStream): SerializationStream = {
- threadBuffer.clear()
- new KryoSerializationStream(kryo, threadBuffer, s)
+ new KryoSerializationStream(kryo, s)
}
def deserializeStream(s: InputStream): DeserializationStream = {
- new KryoDeserializationStream(objectBuffer, s)
- }
-
- override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
- threadBuffer.clear()
- while (iterator.hasNext) {
- val element = iterator.next()
- // TODO: Do we also want to write the object's size? Doesn't seem necessary.
- kryo.writeClassAndObject(threadBuffer, element)
- }
- val newBuf = ByteBuffer.allocate(threadBuffer.position)
- threadBuffer.flip()
- newBuf.put(threadBuffer)
- newBuf.flip()
- newBuf
- }
-
- override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
- buffer.rewind()
- new Iterator[Any] {
- override def hasNext: Boolean = buffer.remaining > 0
- override def next(): Any = kryo.readClassAndObject(buffer)
- }
+ new KryoDeserializationStream(kryo, s)
}
}
@@ -171,18 +98,19 @@ trait KryoRegistrator {
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
- // Make this lazy so that it only gets called once we receive our first task on each executor,
- // so we can pull out any custom Kryo registrator from the user's JARs.
- lazy val kryo = createKryo()
- val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
+ val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
- val objectBuffer = new ThreadLocal[ObjectBuffer] {
- override def initialValue = new ObjectBuffer(kryo, bufferSize)
+ val kryo = new ThreadLocal[Kryo] {
+ override def initialValue = createKryo()
}
- val threadBuffer = new ThreadLocal[ByteBuffer] {
- override def initialValue = ByteBuffer.allocate(bufferSize)
+ val output = new ThreadLocal[KryoOutput] {
+ override def initialValue = new KryoOutput(bufferSize)
+ }
+
+ val input = new ThreadLocal[KryoInput] {
+ override def initialValue = new KryoInput(bufferSize)
}
def createKryo(): Kryo = {
@@ -213,41 +141,44 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo.register(obj.getClass)
}
- // Register the following classes for passing closures.
- kryo.register(classOf[Class[_]], new ClassSerializer(kryo))
- kryo.setRegistrationOptional(true)
-
// Allow sending SerializableWritable
- kryo.register(classOf[SerializableWritable[_]], new SerializableSerializer())
- kryo.register(classOf[HttpBroadcast[_]], new SerializableSerializer())
+ kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
+ kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
// Register some commonly used Scala singleton objects. Because these
// are singletons, we must return the exact same local object when we
// deserialize rather than returning a clone as FieldSerializer would.
- class SingletonSerializer(obj: AnyRef) extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = obj.asInstanceOf[T]
+ class SingletonSerializer[T](obj: T) extends KSerializer[T] {
+ override def write(kryo: Kryo, output: KryoOutput, obj: T) {}
+ override def read(kryo: Kryo, input: KryoInput, cls: java.lang.Class[T]): T = obj
}
- kryo.register(None.getClass, new SingletonSerializer(None))
- kryo.register(Nil.getClass, new SingletonSerializer(Nil))
+ kryo.register(None.getClass, new SingletonSerializer[AnyRef](None))
+ kryo.register(Nil.getClass, new SingletonSerializer[AnyRef](Nil))
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
- extends KSerializer {
- override def writeObjectData(buf: ByteBuffer, obj: AnyRef) {
+ extends KSerializer[scala.collection.Map[Any, Any]] { // serializes maps, not the builder
+ override def write(
+ kryo: Kryo,
+ output: KryoOutput,
+ obj: scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
- kryo.writeObject(buf, map.size.asInstanceOf[java.lang.Integer])
+ kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
for ((k, v) <- map) {
- kryo.writeClassAndObject(buf, k)
- kryo.writeClassAndObject(buf, v)
+ kryo.writeClassAndObject(output, k)
+ kryo.writeClassAndObject(output, v)
}
}
- override def readObjectData[T](buf: ByteBuffer, cls: Class[T]): T = {
- val size = kryo.readObject(buf, classOf[java.lang.Integer]).intValue
+ override def read(
+ kryo: Kryo,
+ input: KryoInput,
+ cls: Class[scala.collection.Map[Any, Any]])
+ : scala.collection.Map[Any, Any] = {
+ val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
- elems(i) = (kryo.readClassAndObject(buf), kryo.readClassAndObject(buf))
- buildMap(elems).asInstanceOf[T]
+ elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
+ buildMap(elems) // rebuild the concrete map type from the (key, value) pairs read above
}
}
kryo.register(mutable.HashMap().getClass, new ScalaMapSerializer(mutable.HashMap() ++ _))
@@ -275,5 +206,8 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
kryo
}
- def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
+ def newInstance(): SerializerInstance = {
+ this.kryo.get().setClassLoader(Thread.currentThread().getContextClassLoader)
+ new KryoSerializerInstance(this)
+ }
}
diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala
index 90bae26202..7c1c1bb144 100644
--- a/core/src/main/scala/spark/Logging.scala
+++ b/core/src/main/scala/spark/Logging.scala
@@ -11,8 +11,7 @@ import org.slf4j.LoggerFactory
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient
- private var log_ : Logger = null
+ @transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 70eb9f702e..ac02f3363a 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -17,6 +17,7 @@ import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -44,7 +45,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val timeout = 10.seconds
- var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -53,7 +54,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
- val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
+ val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@@ -64,6 +65,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
actorSystem.actorFor(url)
}
+ val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
@@ -84,14 +87,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.get(shuffleId) != null) {
+ if (mapStatuses.get(shuffleId) != None) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
@@ -108,7 +111,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var array = mapStatuses.get(shuffleId)
+ var array = mapStatuses.get(shuffleId).orNull // null (not an exception) for unknown shuffles
if (array != null) {
array.synchronized {
if (array(mapId) != null && array(mapId).address == bmAddress) {
@@ -126,7 +129,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
- val statuses = mapStatuses.get(shuffleId)
+ val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
@@ -139,8 +142,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
- return mapStatuses.get(shuffleId).map(status =>
- (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses.get(shuffleId).orNull)
} else {
fetching += shuffleId
}
@@ -156,27 +158,27 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
- if (fetchedStatuses.contains(null)) {
- throw new FetchFailedException(null, shuffleId, -1, reduceId,
- new Exception("Missing an output location for shuffle " + shuffleId))
- }
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
- return fetchedStatuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
} else {
- return statuses.map(s =>
- (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
+ def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
+ metadataCleaner.cancel()
trackerActor = null
}
@@ -202,7 +204,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
- mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
+ mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
@@ -220,7 +222,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses.get(shuffleId)
+ statuses = mapStatuses(shuffleId)
generationGotten = generation
}
}
@@ -258,6 +260,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
+ // Translate the MapStatuses of one shuffle into (location, size) pairs for a given reduce ID.
+ // A null array, or a null entry inside it (a mapper whose output location is missing),
+ // is reported by throwing a FetchFailedException.
+ def convertMapStatuses(
+ shuffleId: Int,
+ reduceId: Int,
+ statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ def fail(message: String): Nothing =
+ throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception(message))
+
+ if (statuses == null) {
+ fail("Missing all output locations for shuffle " + shuffleId)
+ }
+ for (status <- statuses) yield {
+ if (status != null) {
+ (status.address, decompressSize(status.compressedSizes(reduceId)))
+ } else {
+ fail("Missing an output location for shuffle " + shuffleId)
+ }
+ }
+ }
+
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 08ae06e865..53b051f1c5 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -52,6 +52,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
@@ -92,6 +100,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+
+ if (getKeyClass().isArray) {
+ throw new SparkException("reduceByKeyLocally() does not support array keys")
+ }
+
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
for ((k, v) <- iter) {
@@ -165,6 +178,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* be set to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@@ -178,9 +199,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
- * Merge the values for each key using an associative reduce function. This will also perform
- * the merging locally on each mapper before sending results to a reducer, similarly to a
- * "combiner" in MapReduce.
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
@@ -336,6 +357,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
partitioner)
@@ -352,6 +376,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
@@ -438,7 +465,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
val res = self.context.runJob(self, process _, Array(index), false)
res(0)
case None =>
- throw new UnsupportedOperationException("lookup() called on an RDD without a partitioner")
+ self.filter(_._1 == key).map(_._2).collect
}
}
@@ -466,20 +493,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
path: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
- saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
- }
-
- /**
- * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
- * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
- */
- def saveAsNewAPIHadoopFile(
- path: String,
- keyClass: Class[_],
- valueClass: Class[_],
outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
- conf: Configuration) {
+ conf: Configuration = self.context.hadoopConfiguration) {
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
@@ -530,7 +545,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
- conf: JobConf = new JobConf) {
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
@@ -588,6 +603,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys: RDD[K] = self.map(_._1)
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values: RDD[V] = self.map(_._2)
+
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
@@ -624,24 +649,23 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
- override def compute(split: Split, taskContext: TaskContext) =
- prev.iterator(split, taskContext).map{case (k, v) => (k, f(v))}
+class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
+ extends RDD[(K, U)](prev) {
+
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) }
}
private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev.context) {
-
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override val partitioner = prev.partitioner
+ extends RDD[(K, U)](prev) {
- override def compute(split: Split, taskContext: TaskContext) = {
- prev.iterator(split, taskContext).flatMap { case (k, v) => f(v).map(x => (k, x)) }
+ override def getSplits = firstParent[(K, V)].splits
+ override val partitioner = firstParent[(K, V)].partitioner
+ override def compute(split: Split, context: TaskContext) = {
+ firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) }
}
}
diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index a27f766e31..10adcd53ec 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -2,6 +2,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
@@ -22,28 +23,33 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
}
private[spark] class ParallelCollection[T: ClassManifest](
- sc: SparkContext,
+ @transient sc: SparkContext,
@transient data: Seq[T],
- numSlices: Int)
- extends RDD[T](sc) {
+ numSlices: Int,
+ locationPrefs: Map[Int,Seq[String]])
+ extends RDD[T](sc, Nil) {
// TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
// cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
// instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val slices = ParallelCollection.slice(data, numSlices).toArray
slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray
}
- override def splits = splits_.asInstanceOf[Array[Split]]
+ override def getSplits = splits_
- override def compute(s: Split, taskContext: TaskContext) =
+ override def compute(s: Split, context: TaskContext) =
s.asInstanceOf[ParallelCollectionSplit[T]].iterator
- override def preferredLocations(s: Split): Seq[String] = Nil
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ locationPrefs.getOrElse(s.index, Nil)
+ }
- override val dependencies: List[Dependency[_]] = Nil
+ override def clearDependencies() {
+ splits_ = null
+ }
}
private object ParallelCollection {
diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index b71021a082..9d5b966e1e 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -11,6 +11,10 @@ abstract class Partitioner extends Serializable {
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index bb4c13c494..c79f34342f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -1,10 +1,8 @@
package spark
-import java.io.EOFException
-import java.io.ObjectInputStream
+import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
-import java.util.Random
-import java.util.Date
+import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
@@ -13,6 +11,7 @@ import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
@@ -73,41 +72,42 @@ import SparkContext._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
+abstract class RDD[T: ClassManifest](
+ @transient var sc: SparkContext,
+ var dependencies_ : List[Dependency[_]]
+ ) extends Serializable with Logging {
- // Methods that must be implemented by subclasses:
- /** Set of partitions in this RDD. */
- def splits: Array[Split]
+ def this(@transient oneParent: RDD[_]) =
+ this(oneParent.context , List(new OneToOneDependency(oneParent)))
+
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
/** Function for computing a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
- /** How this RDD depends on any parent RDDs. */
- @transient val dependencies: List[Dependency[_]]
+ /** Set of partitions in this RDD. */
+ protected def getSplits(): Array[Split]
- // Methods available on all RDDs:
+ /** How this RDD depends on any parent RDDs. */
+ protected def getDependencies(): List[Dependency[_]] = dependencies_
- /** Record user function generating this RDD. */
- private[spark] val origin = Utils.getSparkCallSite
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
- /** Optionally overridden by subclasses to specify placement preferences. */
- def preferredLocations(split: Split): Seq[String] = Nil
-
- /** The [[spark.SparkContext]] that this RDD was created on. */
- def context = sc
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
- // Variables relating to persistence
- private var storageLevel: StorageLevel = StorageLevel.NONE
-
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
@@ -131,22 +131,39 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
- private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
- if (!level.useDisk && level.replication < 2) {
- throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
+ /**
+ * Get the preferred location of a split, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def preferredLocations(split: Split): Seq[String] = {
+ if (isCheckpointed) {
+ checkpointData.get.getPreferredLocations(split)
+ } else {
+ getPreferredLocations(split)
}
+ }
- // This is a hack. Ideally this should re-use the code used by the CacheTracker
- // to generate the key.
- def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
-
- persist(level)
- sc.runJob(this, (iter: Iterator[T]) => {} )
-
- val p = this.partitioner
+ /**
+ * Get the array of splits of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def splits: Array[Split] = {
+ if (isCheckpointed) {
+ checkpointData.get.getSplits
+ } else {
+ getSplits
+ }
+ }
- new BlockRDD[T](sc, splits.map(getSplitKey).toArray) {
- override val partitioner = p
+ /**
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def dependencies: List[Dependency[_]] = {
+ if (isCheckpointed) {
+ dependencies_
+ } else {
+ getDependencies
}
}
@@ -156,8 +173,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
- if (storageLevel != StorageLevel.NONE) {
- SparkEnv.get.cacheTracker.getOrCompute[T](this, split, context, storageLevel)
+ if (isCheckpointed) {
+ checkpointData.get.iterator(split, context)
+ } else if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
compute(split, context)
}
@@ -185,9 +204,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
- def distinct(numSplits: Int = splits.size): RDD[T] =
+ def distinct(numSplits: Int): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
+ def distinct(): RDD[T] = distinct(splits.size)
+
/**
* Return a sampled subset of this RDD.
*/
@@ -328,6 +349,13 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
def toArray(): Array[T] = collect()
/**
+ * Return an RDD that contains all matching values by applying `f`.
+ */
+ def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ filter(f.isDefinedAt).map(f)
+ }
+
+ /**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
@@ -415,6 +443,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValue() does not support arrays")
+ }
// TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
@@ -443,6 +474,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValueApprox() does not support arrays")
+ }
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
val map = new OLMap[T]
while (iter.hasNext) {
@@ -502,8 +536,95 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
.saveAsSequenceFile(path)
}
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: T => K): RDD[(K, T)] = {
+ map(x => (f(x), x))
+ }
+
/** A private method for tests, to look at the contents of each partition */
private[spark] def collectPartitions(): Array[Array[T]] = {
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() {
+ if (context.checkpointDir.isEmpty) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ } else if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
+ }
+ }
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed(): Boolean = {
+ if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
+ }
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Option[String] = {
+ if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
+ }
+
+ // =======================================================================
+ // Other internal methods and fields
+ // =======================================================================
+
+ private var storageLevel: StorageLevel = StorageLevel.NONE
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.getSparkCallSite
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+ private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+ /** Returns the first parent RDD */
+ protected[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** The [[spark.SparkContext]] that this RDD was created on. */
+ def context = sc
+
+ /**
+ * Performs the checkpointing of this RDD by saving it to a file. It is called by the DAGScheduler
+ * after a job using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ */
+ protected[spark] def doCheckpoint() {
+ if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to the new RDD
+ * (`newRDD`) created from the checkpoint file.
+ */
+ protected[spark] def changeDependencies(newRDD: RDD[_]) {
+ clearDependencies()
+ dependencies_ = List(new OneToOneDependency(newRDD))
+ }
+
+ /**
+ * Clears the dependencies of this RDD. This method must ensure that all references
+ * to the original parent RDDs are removed to enable the parent RDDs to be garbage
+ * collected. Subclasses of RDD may override this method for implementing their own cleaning
+ * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ */
+ protected[spark] def clearDependencies() {
+ dependencies_ = null
+ }
}
diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000..18df530b7d
--- /dev/null
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -0,0 +1,105 @@
+package spark
+
+import org.apache.hadoop.fs.Path
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information related to RDD checkpointing. Each instance of this class
+ * is associated with an RDD. It manages the process of checkpointing the associated RDD, and also
+ * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations
+ * of the checkpointed RDD.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+extends Logging with Serializable {
+
+ import CheckpointState._
+
+ // The checkpoint state of the associated RDD.
+ var cpState = Initialized
+
+ // The file to which the associated RDD has been checkpointed
+ @transient var cpFile: Option[String] = None
+
+ // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
+ @transient var cpRDD: Option[RDD[T]] = None
+
+ // Mark the RDD for checkpointing
+ def markForCheckpoint() {
+ RDDCheckpointData.synchronized {
+ if (cpState == Initialized) cpState = MarkedForCheckpoint
+ }
+ }
+
+ // Is the RDD already checkpointed
+ def isCheckpointed(): Boolean = {
+ RDDCheckpointData.synchronized { cpState == Checkpointed }
+ }
+
+ // Get the file to which this RDD was checkpointed, as an Option
+ def getCheckpointFile(): Option[String] = {
+ RDDCheckpointData.synchronized { cpFile }
+ }
+
+ // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+ def doCheckpoint() {
+ // If it is marked for checkpointing AND checkpointing is not already in progress,
+ // then set it to be in progress, else return
+ RDDCheckpointData.synchronized {
+ if (cpState == MarkedForCheckpoint) {
+ cpState = CheckpointingInProgress
+ } else {
+ return
+ }
+ }
+
+ // Save to file, and reload it as an RDD
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path)
+
+ // Change the dependencies and splits of the RDD
+ RDDCheckpointData.synchronized {
+ cpFile = Some(path)
+ cpRDD = Some(newRDD)
+ rdd.changeDependencies(newRDD)
+ cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
+ logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+ }
+ }
+
+ // Get preferred location of a split after checkpointing
+ def getPreferredLocations(split: Split) = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.preferredLocations(split)
+ }
+ }
+
+ def getSplits: Array[Split] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.splits
+ }
+ }
+
+ // Get iterator. This is called at the worker nodes.
+ def iterator(split: Split, context: TaskContext): Iterator[T] = {
+ rdd.firstParent[T].iterator(split, context)
+ }
+}
+
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
index a34aee69c1..6b4a11d6d3 100644
--- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala
@@ -42,7 +42,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
} else {
- implicitly[T => Writable].getClass.getMethods()(0).getReturnType
+ // We get the type of the Writable class by looking at the apply method which converts
+ // from T to Writable. Since we have two apply methods we filter out the one which
+ // is of the form "java.lang.Object apply(java.lang.Object)"
+ implicitly[T => Writable].getClass.getDeclaredMethods().filter(
+ m => m.getReturnType().toString != "java.lang.Object" &&
+ m.getName() == "apply")(0).getReturnType
+
}
// TODO: use something like WritableConverter to avoid reflection
}
diff --git a/core/src/main/scala/spark/SizeEstimator.scala b/core/src/main/scala/spark/SizeEstimator.scala
index 7c3e8640e9..d4e1157250 100644
--- a/core/src/main/scala/spark/SizeEstimator.scala
+++ b/core/src/main/scala/spark/SizeEstimator.scala
@@ -9,7 +9,6 @@ import java.util.Random
import javax.management.MBeanServer
import java.lang.management.ManagementFactory
-import com.sun.management.HotSpotDiagnosticMXBean
import scala.collection.mutable.ArrayBuffer
@@ -76,12 +75,20 @@ private[spark] object SizeEstimator extends Logging {
if (System.getProperty("spark.test.useCompressedOops") != null) {
return System.getProperty("spark.test.useCompressedOops").toBoolean
}
+
try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer()
+
+ // NOTE: This should throw an exception in non-Sun JVMs
+ val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
+ val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
+ Class.forName("java.lang.String"))
+
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
- hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
- return bean.getVMOption("UseCompressedOops").getValue.toBoolean
+ hotSpotMBeanName, hotSpotMBeanClass)
+ // TODO: We could use reflection on the VMOption returned ?
+ return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
} catch {
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 0afab522af..66bdbe7cda 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,10 +3,12 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
+import java.lang.ref.WeakReference
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
import akka.actor.Actor
import akka.actor.Actor._
@@ -36,12 +38,8 @@ import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
-import spark.rdd.HadoopRDD
-import spark.rdd.NewHadoopRDD
-import spark.rdd.UnionRDD
-import spark.scheduler.ShuffleMapTask
-import spark.scheduler.DAGScheduler
-import spark.scheduler.TaskScheduler
+import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD}
+import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
@@ -58,29 +56,13 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
- master: String,
- jobName: String,
- val sparkHome: String,
- jars: Seq[String],
- environment: Map[String, String])
+ val master: String,
+ val jobName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ environment: Map[String, String] = Map())
extends Logging {
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- * @param sparkHome Location where Spark is installed on cluster nodes.
- * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
- * system or HDFS, HTTP, HTTPS, or FTP URLs.
- */
- def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
- this(master, jobName, sparkHome, jars, Map())
-
- /**
- * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
- * @param jobName A name for your job, to display on the cluster web UI
- */
- def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
-
// Ensure logging is initialized before we spawn any threads
initLogging()
@@ -187,11 +169,32 @@ class SparkContext(
private var dagScheduler = new DAGScheduler(taskScheduler)
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration = {
+ val conf = new Configuration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ conf
+ }
+
+ private[spark] var checkpointDir: Option[String] = None
+
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
- new ParallelCollection[T](this, seq, numSlices)
+ new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
@@ -199,6 +202,14 @@ class SparkContext(
parallelize(seq, numSlices)
}
+ /** Distribute a local Scala collection to form an RDD, with one or more
+ * location preferences (hostnames of Spark nodes) for each object.
+ * Create a new partition for each collection item. */
+ def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
+ new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ }
+
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
@@ -231,10 +242,8 @@ class SparkContext(
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
- val conf = new JobConf()
+ val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
- val bufferSize = System.getProperty("spark.buffer.size", "65536")
- conf.set("io.file.buffer.size", bufferSize)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -275,8 +284,7 @@ class SparkContext(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
- new Configuration)
+ vm.erasure.asInstanceOf[Class[V]])
}
/**
@@ -288,7 +296,7 @@ class SparkContext(
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
- conf: Configuration): RDD[(K, V)] = {
+ conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -300,7 +308,7 @@ class SparkContext(
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
- conf: Configuration,
+ conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
@@ -365,6 +373,13 @@ class SparkContext(
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
+
+ protected[spark] def checkpointFile[T: ClassManifest](
+ path: String
+ ): RDD[T] = {
+ new CheckpointRDD[T](this, path)
+ }
+
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
@@ -382,11 +397,12 @@ class SparkContext(
new Accumulator(initialValue, param)
/**
- * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+ * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+ * Only the master can access the accumulable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
- def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+ def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param)
/**
@@ -404,12 +420,13 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
- * Add a file to be downloaded into the working directory of this Spark job on every node.
+ * Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
@@ -419,9 +436,10 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case the task is executed locally
- val filename = new File(path.split("/").last)
- Utils.fetchFile(path, new File("."))
+ // Fetch the file locally in case a job is executed locally.
+ // Jobs that run through LocalScheduler will already fetch the required dependencies,
+ // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
@@ -437,11 +455,10 @@ class SparkContext(
}
/**
- * Clear the job's list of files added by `addFile` so that they do not get donwloaded to
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
- addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
@@ -465,23 +482,27 @@ class SparkContext(
* any new nodes.
*/
def clearJars() {
- addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
/** Shut down the SparkContext. */
def stop() {
- dagScheduler.stop()
- dagScheduler = null
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- // Clean up locally linked files
- clearFiles()
- clearJars()
- SparkEnv.set(null)
- ShuffleMapTask.clearCache()
- logInfo("Successfully stopped SparkContext")
+ if (dagScheduler != null) {
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
+ SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ logInfo("Successfully stopped SparkContext")
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
/**
@@ -518,6 +539,7 @@ class SparkContext(
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ rdd.doCheckpoint()
result
}
@@ -574,6 +596,26 @@ class SparkContext(
return f
}
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overwriting of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean = false) {
+ val path = new Path(dir)
+ val fs = path.getFileSystem(new Configuration())
+ if (!useExisting) {
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ }
+ }
+ checkpointDir = Some(dir)
+ }
+
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
@@ -595,6 +637,7 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 272d7cdad3..2a7a8af83d 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -22,24 +22,19 @@ class SparkEnv (
val actorSystem: ActorSystem,
val serializer: Serializer,
val closureSerializer: Serializer,
- val cacheTracker: CacheTracker,
+ val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
- val httpFileServer: HttpFileServer
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String
) {
- /** No-parameter constructor for unit tests. */
- def this() = {
- this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
- }
-
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
- cacheTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
@@ -86,10 +81,13 @@ object SparkEnv extends Logging {
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
-
- val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
+
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(
+ actorSystem, isMaster, isLocal, masterIp, masterPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer)
-
+
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
@@ -97,18 +95,26 @@ object SparkEnv extends Logging {
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
- val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
- blockManager.cacheTracker = cacheTracker
+ val cacheManager = new CacheManager(blockManager)
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
-
+
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isMaster) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -119,12 +125,13 @@ object SparkEnv extends Logging {
actorSystem,
serializer,
closureSerializer,
- cacheTracker,
+ cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
- httpFileServer)
+ httpFileServer,
+ sparkFilesDir)
}
}
diff --git a/core/src/main/scala/spark/SparkFiles.java b/core/src/main/scala/spark/SparkFiles.java
new file mode 100644
index 0000000000..566aec622c
--- /dev/null
+++ b/core/src/main/scala/spark/SparkFiles.java
@@ -0,0 +1,25 @@
+package spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `SparkContext.addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index d2746b26b3..eab85f85a2 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -5,8 +5,7 @@ import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
- @transient
- val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index 6d64b32174..ae77264372 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -1,7 +1,7 @@
package spark
import java.io._
-import java.net.{NetworkInterface, InetAddress, URL, URI}
+import java.net.{NetworkInterface, InetAddress, Inet4Address, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
@@ -9,6 +9,8 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.io.Source
+import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
/**
* Various utility methods used by Spark.
@@ -110,48 +112,56 @@ private object Utils extends Logging {
}
}
- /** Copy a file on the local file system */
- def copyFile(source: File, dest: File) {
- val in = new FileInputStream(source)
- val out = new FileOutputStream(dest)
- copyStream(in, out, true)
- }
-
- /** Download a file from a given URL to the local filesystem */
- def downloadFile(url: URL, localPath: String) {
- val in = url.openStream()
- val out = new FileOutputStream(localPath)
- Utils.copyStream(in, out, true)
- }
-
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
+ val tempDir = getLocalDir
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
- logInfo("Fetching " + url + " to " + targetFile)
+ logInfo("Fetching " + url + " to " + tempFile)
val in = new URL(url).openStream()
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
case "file" | null =>
- // Remove the file if it already exists
- targetFile.delete()
- // Symlink the file locally.
- if (uri.isAbsolute) {
- // url is absolute, i.e. it starts with "file:///". Extract the source
- // file's absolute path from the url.
- val sourceFile = new File(uri)
- logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ val sourceFile = if (uri.isAbsolute) {
+ new File(uri)
+ } else {
+ new File(url)
+ }
+ if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
} else {
- // url is not absolute, i.e. itself is the path to the source file.
- logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
- FileUtil.symLink(url, targetFile.getAbsolutePath)
+ // Remove the file if it already exists
+ targetFile.delete()
+ // Symlink the file locally.
+ if (uri.isAbsolute) {
+ // url is absolute, i.e. it starts with "file:///". Extract the source
+ // file's absolute path from the url.
+ val sourceFile = new File(uri)
+ logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
+ } else {
+ // url is not absolute, i.e. itself is the path to the source file.
+ logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
+ FileUtil.symLink(url, targetFile.getAbsolutePath)
+ }
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
@@ -159,8 +169,15 @@ private object Utils extends Logging {
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
- val out = new FileOutputStream(targetFile)
+ val out = new FileOutputStream(tempFile)
Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
@@ -171,7 +188,16 @@ private object Utils extends Logging {
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
- FileUtil.chmod(filename, "a+x")
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
}
/**
@@ -212,7 +238,8 @@ private object Utils extends Logging {
// Address resolves to something like 127.0.1.1, which happens on Debian; try to find
// a better address using the local network interfaces
for (ni <- NetworkInterface.getNetworkInterfaces) {
- for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress && !addr.isLoopbackAddress) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
// We've found an address that looks reasonable!
logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
" a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
@@ -247,48 +274,28 @@ private object Utils extends Logging {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
- /**
- * Returns a standard ThreadFactory except all threads are daemons.
- */
- private def newDaemonThreadFactory: ThreadFactory = {
- new ThreadFactory {
- def newThread(r: Runnable): Thread = {
- var t = Executors.defaultThreadFactory.newThread (r)
- t.setDaemon (true)
- return t
- }
- }
- }
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
/**
* Wrapper over newCachedThreadPool.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor = {
- var threadPool = Executors.newCachedThreadPool.asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory (newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
* millisecond.
*/
def getUsedTimeMs(startTimeMs: Long): String = {
- return " " + (System.currentTimeMillis - startTimeMs) + " ms "
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
}
/**
* Wrapper over newFixedThreadPool.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = {
- var threadPool = Executors.newFixedThreadPool(nThreads).asInstanceOf[ThreadPoolExecutor]
-
- threadPool.setThreadFactory(newDaemonThreadFactory)
-
- return threadPool
- }
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
/**
* Delete a file or directory and its contents recursively.
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 5c2be534ff..8ce32e0e2f 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -471,6 +471,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
fromRDD(new OrderedRDDFunctions(rdd).sortByKey(ascending))
}
+
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
}
object JavaPairRDD {
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 81d3a94466..b3698ffa44 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -9,6 +9,7 @@ import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import spark.partial.{PartialResult, BoundedDouble}
import spark.storage.StorageLevel
+import com.google.common.base.Optional
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@@ -298,4 +299,36 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
+ implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ JavaPairRDD.fromRDD(rdd.keyBy(f))
+ }
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() = rdd.checkpoint()
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed(): Boolean = rdd.isCheckpointed()
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile(): Optional[String] = {
+ rdd.getCheckpointFile match {
+ case Some(file) => Optional.of(file)
+ case _ => Optional.absent()
+ }
+ }
}
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index edbb187b1b..50b8970cd8 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
import spark.SparkContext.IntAccumulatorParam
import spark.SparkContext.DoubleAccumulatorParam
import spark.broadcast.Broadcast
@@ -265,26 +265,46 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def intAccumulator(initialValue: Int): Accumulator[Int] =
- sc.accumulator(initialValue)(IntAccumulatorParam)
+ def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+ sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def doubleAccumulator(initialValue: Double): Accumulator[Double] =
- sc.accumulator(initialValue)(DoubleAccumulatorParam)
+ def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
+
+ /**
+ * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
+
+ /**
+ * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ doubleAccumulator(initialValue)
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
/**
+ * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+ * "add" values with `add`. Only the master can access the accumulable's `value`.
+ */
+ def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+ sc.accumulable(initialValue)(param)
+
+ /**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
@@ -301,6 +321,75 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* (in that order of preference). If neither of these is set, return None.
*/
def getSparkHome(): Option[String] = sc.getSparkHome()
+
+ /**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
+ */
+ def addFile(path: String) {
+ sc.addFile(path)
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ sc.addJar(path)
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ sc.clearJars()
+ }
+
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ sc.clearFiles()
+ }
+
+ /**
+ * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
+ */
+ def hadoopConfiguration(): Configuration = {
+ sc.hadoopConfiguration
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overwriting of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean) {
+ sc.setCheckpointDir(dir, useExisting)
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be a HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overwriting of checkpoint files.
+ */
+ def setCheckpointDir(dir: String) {
+ sc.setCheckpointDir(dir)
+ }
+
+ protected def checkpointFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ new JavaRDD(sc.checkpointFile(path))
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/spark/api/java/StorageLevels.java b/core/src/main/scala/spark/api/java/StorageLevels.java
index 722af3c06c..5e5845ac3a 100644
--- a/core/src/main/scala/spark/api/java/StorageLevels.java
+++ b/core/src/main/scala/spark/api/java/StorageLevels.java
@@ -17,4 +17,15 @@ public class StorageLevels {
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
+ return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ }
}
diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..519e310323
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,48 @@
+package spark.api.python
+
+import spark.Partitioner
+
+import java.util.Arrays
+
+/**
+ * A [[spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
+ */
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
+
+ override def getPartition(key: Any): Int = {
+ if (key == null) {
+ return 0
+ }
+ else {
+ val hashCode = {
+ if (key.isInstanceOf[Array[Byte]]) {
+ Arrays.hashCode(key.asInstanceOf[Array[Byte]])
+ } else {
+ key.hashCode()
+ }
+ }
+ val mod = hashCode % numPartitions
+ if (mod < 0) {
+ mod + numPartitions
+ } else {
+ mod // Guard against negative hash codes
+ }
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: PythonPartitioner =>
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
+ case _ =>
+ false
+ }
+}
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
new file mode 100644
index 0000000000..f43a152ca7
--- /dev/null
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -0,0 +1,293 @@
+package spark.api.python
+
+import java.io._
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
+
+import scala.collection.JavaConversions._
+import scala.io.Source
+
+import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import spark.broadcast.Broadcast
+import spark._
+import spark.rdd.PipedRDD
+
+
+private[spark] class PythonRDD[T: ClassManifest](
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
+ extends RDD[Array[Byte]](parent) {
+
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
+ broadcastVars, accumulator)
+
+ override def getSplits = parent.splits
+
+ override val partitioner = if (preservePartitoning) parent.partitioner else None
+
+ override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = {
+ val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
+
+ val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
+ // Add the environmental variables to the process.
+ val currentEnvVars = pb.environment()
+
+ for ((variable, value) <- envVars) {
+ currentEnvVars.put(variable, value)
+ }
+
+ val proc = pb.start()
+ val env = SparkEnv.get
+
+ // Start a thread to print the process's stderr to ours
+ new Thread("stderr reader for " + command) {
+ override def run() {
+ for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+
+ // Start a thread to feed the process input from our parent's iterator
+ new Thread("stdin writer for " + command) {
+ override def run() {
+ SparkEnv.set(env)
+ val out = new PrintWriter(proc.getOutputStream)
+ val dOut = new DataOutputStream(proc.getOutputStream)
+ // Split index
+ dOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+ // Broadcast variables
+ dOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dOut.writeLong(broadcast.id)
+ dOut.writeInt(broadcast.value.length)
+ dOut.write(broadcast.value)
+ dOut.flush()
+ }
+ // Serialized user code
+ for (elem <- command) {
+ out.println(elem)
+ }
+ out.flush()
+ // Data values
+ for (elem <- parent.iterator(split, context)) {
+ PythonRDD.writeAsPickle(elem, dOut)
+ }
+ dOut.flush()
+ out.flush()
+ proc.getOutputStream.close()
+ }
+ }.start()
+
+ // Return an iterator that reads length-prefixed byte arrays from the process's stdout
+ val stream = new DataInputStream(proc.getInputStream)
+ return new Iterator[Array[Byte]] {
+ def next(): Array[Byte] = {
+ val obj = _nextObj
+ _nextObj = read()
+ obj
+ }
+
+ private def read(): Array[Byte] = {
+ try {
+ val length = stream.readInt()
+ if (length != -1) {
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ } else {
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
+ } catch {
+ case eof: EOFException => {
+ val exitStatus = proc.waitFor()
+ if (exitStatus != 0) {
+ throw new Exception("Subprocess exited with status " + exitStatus)
+ }
+ new Array[Byte](0)
+ }
+ case e => throw e
+ }
+ }
+
+ var _nextObj = read()
+
+ def hasNext = _nextObj.length != 0
+ }
+ }
+
+ val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/**
+ * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
+ * This is used by PySpark's shuffle operations.
+ */
+private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
+ RDD[(Array[Byte], Array[Byte])](prev) {
+ override def getSplits = prev.splits
+ override def compute(split: Split, context: TaskContext) =
+ prev.iterator(split, context).grouped(2).map {
+ case Seq(a, b) => (a, b)
+ case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+ }
+ val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+private[spark] object PythonRDD {
+
+  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
+  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
+    // Drops the first two bytes (PROTO opcode and its version argument) and the last (STOP)
+    arr.slice(2, arr.length - 1)
+  }
+
+  /**
+   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+   * followed by the pickled data.
+   *
+   * Pickle module:
+   *
+   *    http://docs.python.org/2/library/pickle.html
+   *
+   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
+   *
+   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
+   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
+   *
+   * Any element that is not a byte array, a pair of byte arrays or a String is rejected with
+   * an Exception("Unexpected RDD type").
+   *
+   * @param elem the object to write
+   * @param dOut a data output stream
+   */
+  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
+    // Typed patterns check the tuple's components at runtime; a generic
+    // isInstanceOf[Tuple2[...]] test cannot, because its type arguments are erased.
+    elem match {
+      case arr: Array[Byte] =>
+        // Already pickled: just frame it
+        dOut.writeInt(arr.length)
+        dOut.write(arr)
+      case (key: Array[Byte], value: Array[Byte]) =>
+        // Each stripPickle() call removes 3 bytes; the tuple wrapper written below
+        // (PROTO, version, TUPLE2, STOP) adds 4.
+        val length = key.length + value.length - 3 - 3 + 4
+        dOut.writeInt(length)
+        dOut.writeByte(Pickle.PROTO)
+        dOut.writeByte(Pickle.TWO)
+        dOut.write(PythonRDD.stripPickle(key))
+        dOut.write(PythonRDD.stripPickle(value))
+        dOut.writeByte(Pickle.TUPLE2)
+        dOut.writeByte(Pickle.STOP)
+      case str: String =>
+        // For uniformity, strings are wrapped into Pickles.
+        val s = str.getBytes("UTF-8")
+        // PROTO + version (2), BINUNICODE (1), 4-byte length, payload, STOP (1)
+        val length = 2 + 1 + 4 + s.length + 1
+        dOut.writeInt(length)
+        dOut.writeByte(Pickle.PROTO)
+        dOut.writeByte(Pickle.TWO)
+        dOut.writeByte(Pickle.BINUNICODE)
+        // writeInt is big-endian; reversing the bytes emits the length little-endian
+        dOut.writeInt(Integer.reverseBytes(s.length))
+        dOut.write(s)
+        dOut.writeByte(Pickle.STOP)
+      case _ =>
+        throw new Exception("Unexpected RDD type")
+    }
+  }
+
+  /**
+   * Reads every length-prefixed record from `filename` and turns the records into an RDD
+   * by calling `parallelize` with `parallelism` as its second argument.
+   *
+   * Reading stops at the first EOFException. The input stream is always closed, even
+   * when reading fails with another exception.
+   */
+  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+    JavaRDD[Array[Byte]] = {
+    val file = new DataInputStream(new FileInputStream(filename))
+    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+    try {
+      while (true) {
+        val length = file.readInt()
+        val obj = new Array[Byte](length)
+        file.readFully(obj)
+        objs.append(obj)
+      }
+    } catch {
+      // Running off the end of the file is the normal way out of the loop above
+      case eof: EOFException => {}
+    } finally {
+      file.close()
+    }
+    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+  }
+
+  /**
+   * Writes every item of `items` to `filename` using [[writeAsPickle]]. The output stream
+   * is closed even if writing an item throws.
+   */
+  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+    val file = new DataOutputStream(new FileOutputStream(filename))
+    try {
+      for (item <- items) {
+        writeAsPickle(item, file)
+      }
+    } finally {
+      file.close()
+    }
+  }
+
+  /** Runs a job over the single partition `partition` of `rdd` and returns its iterator. */
+  def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
+    rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+}
+
+/** Python pickle opcodes (see the `pickletools` module) needed to build pickles by hand. */
+private object Pickle {
+  // Protocol header: PROTO opcode followed by the protocol version
+  val PROTO: Byte = 0x80.toByte
+  val TWO: Byte = 0x02.toByte
+  // Unicode string: 4-byte little-endian length followed by UTF-8 bytes
+  val BINUNICODE: Byte = 'X'.toByte
+  // End of pickle
+  val STOP: Byte = '.'.toByte
+  // Build a 2-tuple from the two topmost stack items
+  val TUPLE2: Byte = 0x86.toByte
+  // List construction opcodes
+  val EMPTY_LIST: Byte = ']'.toByte
+  val MARK: Byte = '('.toByte
+  val APPENDS: Byte = 'e'.toByte
+}
+
+/** Function that decodes a byte array as a UTF-8 string. */
+private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
+  override def call(arr: Array[Byte]) : String = {
+    val decoded = new String(arr, "UTF-8")
+    decoded
+  }
+}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ *
+ * `serverHost` is marked transient; presumably that is why it is null on worker nodes,
+ * which is how `addInPlace` tells the two sides apart (confirm against the callers).
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+  extends AccumulatorParam[JList[Array[Byte]]] {
+
+  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+  /**
+   * Merges `val2` into `val1`.
+   *
+   * Without a server host the updates are appended to `val1`, which is returned. With a
+   * server host the updates in `val2` are sent over a new socket as a count followed by
+   * length-prefixed arrays, a one-byte acknowledgement is awaited, and null is returned.
+   * The socket is closed even when sending or acknowledging fails.
+   *
+   * Throws a SparkException if the connection ends before the acknowledgement arrives.
+   */
+  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+      : JList[Array[Byte]] = {
+    if (serverHost == null) {
+      // This happens on the worker node, where we just want to remember all the updates
+      val1.addAll(val2)
+      val1
+    } else {
+      // This happens on the master, where we pass the updates to Python through a socket
+      val socket = new Socket(serverHost, serverPort)
+      try {
+        val in = socket.getInputStream
+        val out = new DataOutputStream(socket.getOutputStream)
+        out.writeInt(val2.size)
+        for (array <- val2) {
+          out.writeInt(array.length)
+          out.write(array)
+        }
+        out.flush()
+        // Wait for a byte from the Python side as an acknowledgement
+        val byteRead = in.read()
+        if (byteRead == -1) {
+          throw new SparkException("EOF reached before Python server acknowledged")
+        }
+      } finally {
+        // Release the connection on every path, including the failure above
+        socket.close()
+      }
+      null
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/broadcast/Broadcast.scala b/core/src/main/scala/spark/broadcast/Broadcast.scala
index 6055bfd045..2ffe7f741d 100644
--- a/core/src/main/scala/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/spark/broadcast/Broadcast.scala
@@ -5,7 +5,7 @@ import java.util.concurrent.atomic.AtomicLong
import spark._
-abstract class Broadcast[T](id: Long) extends Serializable {
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 7eb4ddb74f..8e490e6bad 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -11,6 +11,7 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
+import util.{MetadataCleaner, TimeStampedHashSet}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
@@ -64,6 +65,10 @@ private object HttpBroadcast extends Logging {
private var serverUri: String = null
private var server: HttpServer = null
+ private val files = new TimeStampedHashSet[String]
+ private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+
+
def initialize(isMaster: Boolean) {
synchronized {
if (!initialized) {
@@ -85,11 +90,12 @@ private object HttpBroadcast extends Logging {
server = null
}
initialized = false
+ cleaner.cancel()
}
}
private def createServer() {
- broadcastDir = Utils.createTempDir()
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
@@ -108,6 +114,7 @@ private object HttpBroadcast extends Logging {
val serOut = ser.serializeStream(out)
serOut.writeObject(value)
serOut.close()
+ files += file.getAbsolutePath
}
def read[T](id: Long): T = {
@@ -123,4 +130,21 @@ private object HttpBroadcast extends Logging {
serIn.close()
obj
}
+
+ /**
+  * Stops tracking and deletes every recorded broadcast file whose timestamp is older
+  * than `cleanupTime`.
+  *
+  * A file is dropped from the tracked set before deletion is attempted, so a file that
+  * cannot be deleted is reported once and not retried.
+  *
+  * @param cleanupTime entries with a timestamp strictly below this value are removed
+  */
+ def cleanup(cleanupTime: Long) {
+   val iterator = files.internalMap.entrySet().iterator()
+   while(iterator.hasNext) {
+     val entry = iterator.next()
+     val (file, time) = (entry.getKey, entry.getValue)
+     if (time < cleanupTime) {
+       try {
+         iterator.remove()
+         // File.delete() signals failure through its return value rather than an
+         // exception, so the result has to be checked before reporting success.
+         if (new File(file.toString).delete()) {
+           logInfo("Deleted broadcast file '" + file + "'")
+         } else {
+           logWarning("Could not delete broadcast file '" + file + "'")
+         }
+       } catch {
+         case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+       }
+     }
+   }
+ }
}
diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala
index 457122745b..35f40c6e91 100644
--- a/core/src/main/scala/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,7 +4,6 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, JobInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
-import scala.collection.mutable.HashMap
private[spark] sealed trait DeployMessage extends Serializable
@@ -42,7 +41,8 @@ private[spark] case class LaunchExecutor(
execId: Int,
jobDesc: JobDescription,
cores: Int,
- memory: Int)
+ memory: Int,
+ sparkHome: String)
extends DeployMessage
diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/JobDescription.scala
index 20879c5f11..7160fc05fc 100644
--- a/core/src/main/scala/spark/deploy/JobDescription.scala
+++ b/core/src/main/scala/spark/deploy/JobDescription.scala
@@ -4,7 +4,8 @@ private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
- val command: Command)
+ val command: Command,
+ val sparkHome: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala
new file mode 100644
index 0000000000..732fa08064
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -0,0 +1,78 @@
+package spark.deploy
+
+import master.{JobInfo, WorkerInfo}
+import worker.ExecutorRunner
+import cc.spray.json._
+
+/**
+ * spray-json helper class containing implicit conversion to json for marshalling responses
+ *
+ * Every format below is a RootJsonWriter, so these types can be rendered to JSON but no
+ * reader is defined for parsing them back. All keys are written in lower case:
+ *  - WorkerInfoJsonFormat: id, host, web UI address and total/used cores and memory
+ *  - JobInfoJsonFormat: start time, id, submit date and fields of the job's description
+ *  - JobDescriptionJsonFormat: name, cores, memory per slave and user
+ *  - ExecutorRunnerJsonFormat: executor id, memory, job id and the nested job description
+ *  - MasterStateJsonFormat: master URL, workers, totals over the workers and job lists
+ *  - WorkerStateJsonFormat: worker id, master URLs, resource usage and executor lists
+ */
+private[spark] object JsonProtocol extends DefaultJsonProtocol {
+ implicit object WorkerInfoJsonFormat extends RootJsonWriter[WorkerInfo] {
+ def write(obj: WorkerInfo) = JsObject(
+ "id" -> JsString(obj.id),
+ "host" -> JsString(obj.host),
+ "webuiaddress" -> JsString(obj.webUiAddress),
+ "cores" -> JsNumber(obj.cores),
+ "coresused" -> JsNumber(obj.coresUsed),
+ "memory" -> JsNumber(obj.memory),
+ "memoryused" -> JsNumber(obj.memoryUsed)
+ )
+ }
+
+ implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] {
+ def write(obj: JobInfo) = JsObject(
+ "starttime" -> JsNumber(obj.startTime),
+ "id" -> JsString(obj.id),
+ "name" -> JsString(obj.desc.name),
+ "cores" -> JsNumber(obj.desc.cores),
+ "user" -> JsString(obj.desc.user),
+ "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
+ "submitdate" -> JsString(obj.submitDate.toString))
+ }
+
+ implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] {
+ def write(obj: JobDescription) = JsObject(
+ "name" -> JsString(obj.name),
+ "cores" -> JsNumber(obj.cores),
+ "memoryperslave" -> JsNumber(obj.memoryPerSlave),
+ "user" -> JsString(obj.user)
+ )
+ }
+
+ implicit object ExecutorRunnerJsonFormat extends RootJsonWriter[ExecutorRunner] {
+ def write(obj: ExecutorRunner) = JsObject(
+ "id" -> JsNumber(obj.execId),
+ "memory" -> JsNumber(obj.memory),
+ "jobid" -> JsString(obj.jobId),
+ "jobdesc" -> obj.jobDesc.toJson.asJsObject // nested object; presumably via JobDescriptionJsonFormat
+ )
+ }
+
+ implicit object MasterStateJsonFormat extends RootJsonWriter[MasterState] {
+ def write(obj: MasterState) = JsObject(
+ "url" -> JsString("spark://" + obj.uri),
+ "workers" -> JsArray(obj.workers.toList.map(_.toJson)),
+ // NOTE(review): the four totals below sum over every entry of obj.workers; if that
+ // collection can hold workers marked DEAD they are counted too - confirm.
+ "cores" -> JsNumber(obj.workers.map(_.cores).sum),
+ "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum),
+ "memory" -> JsNumber(obj.workers.map(_.memory).sum),
+ "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum),
+ "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)),
+ "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson))
+ )
+ }
+
+ implicit object WorkerStateJsonFormat extends RootJsonWriter[WorkerState] {
+ def write(obj: WorkerState) = JsObject(
+ "id" -> JsString(obj.workerId),
+ "masterurl" -> JsString(obj.masterUrl),
+ "masterwebuiurl" -> JsString(obj.masterWebUiUrl),
+ "cores" -> JsNumber(obj.cores),
+ "coresused" -> JsNumber(obj.coresUsed),
+ "memory" -> JsNumber(obj.memory),
+ "memoryused" -> JsNumber(obj.memoryUsed),
+ "executors" -> JsArray(obj.executors.toList.map(_.toJson)),
+ "finishedexecutors" -> JsArray(obj.finishedExecutors.toList.map(_.toJson))
+ )
+ }
+}
diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala
index 57a7e123b7..8764c400e2 100644
--- a/core/src/main/scala/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new JobDescription(
- "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()))
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()
diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala
index b30c8e99b5..2c2cd0231b 100644
--- a/core/src/main/scala/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/spark/deploy/master/Master.scala
@@ -156,7 +156,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
if (spreadOutJobs) {
// Try to spread out each job among all the nodes, until it has all its cores
for (job <- waitingJobs if job.coresLeft > 0) {
- val usableWorkers = workers.toArray.filter(canUse(job, _)).sortBy(_.coresFree).reverse
+ val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
+ .filter(canUse(job, _)).sortBy(_.coresFree).reverse
val numUsable = usableWorkers.length
val assigned = new Array[Int](numUsable) // Number of cores to give on each node
var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum)
@@ -172,7 +173,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
for (pos <- 0 until numUsable) {
if (assigned(pos) > 0) {
val exec = job.addExecutor(usableWorkers(pos), assigned(pos))
- launchExecutor(usableWorkers(pos), exec)
+ launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -185,7 +186,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val coresToUse = math.min(worker.coresFree, job.coresLeft)
if (coresToUse > 0) {
val exec = job.addExecutor(worker, coresToUse)
- launchExecutor(worker, exec)
+ launchExecutor(worker, exec, job.desc.sparkHome)
job.state = JobState.RUNNING
}
}
@@ -194,15 +195,17 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory)
+ worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome)
exec.job.actor ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
+ // There may be one or more refs to dead workers on this same node (with different IDs); remove them.
+ workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@@ -213,7 +216,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def removeWorker(worker: WorkerInfo) {
logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
- workers -= worker
+ worker.setState(WorkerState.DEAD)
idToWorker -= worker.id
actorToWorker -= worker.actor
addressToWorker -= worker.actor.path.address
diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
index 3cdd3721f5..458ee2d665 100644
--- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
+++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -8,7 +8,11 @@ import akka.util.duration._
import cc.spray.Directives
import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.http.MediaTypes
+import cc.spray.typeconversion.SprayJsonSupport._
+
import spark.deploy._
+import spark.deploy.JsonProtocol._
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
@@ -19,29 +23,51 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val handler = {
get {
- path("") {
- completeWith {
+ (path("") & parameters('format ?)) {
+ case Some(js) if js.equalsIgnoreCase("json") =>
val future = master ? RequestMasterState
- future.map {
- masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(future.mapTo[MasterState])
+ }
+ case _ =>
+ completeWith {
+ val future = master ? RequestMasterState
+ future.map {
+ masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
+ }
}
- }
} ~
path("job") {
- parameter("jobId") { jobId =>
- completeWith {
+ parameters("jobId", 'format ?) {
+ case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) =>
val future = master ? RequestMasterState
- future.map { state =>
- val masterState = state.asInstanceOf[MasterState]
-
- // A bit ugly an inefficient, but we won't have a number of jobs
- // so large that it will make a significant difference.
- (masterState.activeJobs ++ masterState.completedJobs).find(_.id == jobId) match {
- case Some(job) => spark.deploy.master.html.job_details.render(job)
- case _ => null
+ val jobInfo = for (masterState <- future.mapTo[MasterState]) yield {
+ masterState.activeJobs.find(_.id == jobId) match {
+ case Some(job) => job
+ case _ => masterState.completedJobs.find(_.id == jobId) match {
+ case Some(job) => job
+ case _ => null
+ }
+ }
+ }
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(jobInfo.mapTo[JobInfo])
+ }
+ case (jobId, _) =>
+ completeWith {
+ val future = master ? RequestMasterState
+ future.map { state =>
+ val masterState = state.asInstanceOf[MasterState]
+
+ masterState.activeJobs.find(_.id == jobId) match {
+ case Some(job) => spark.deploy.master.html.job_details.render(job)
+ case _ => masterState.completedJobs.find(_.id == jobId) match {
+ case Some(job) => spark.deploy.master.html.job_details.render(job)
+ case _ => null
+ }
+ }
}
}
- }
}
} ~
pathPrefix("static") {
diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
index a0a698ef04..5a7f5fef8a 100644
--- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -14,7 +14,7 @@ private[spark] class WorkerInfo(
val publicAddress: String) {
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
-
+ var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
var memoryUsed = 0
@@ -42,4 +42,8 @@ private[spark] class WorkerInfo(
def webUiAddress : String = {
"http://" + this.publicAddress + ":" + this.webUiPort
}
+
+ /** Replaces the state recorded for this worker with `state`. */
+ def setState(state: WorkerState.Value) = { this.state = state }
}
diff --git a/core/src/main/scala/spark/deploy/master/WorkerState.scala b/core/src/main/scala/spark/deploy/master/WorkerState.scala
new file mode 100644
index 0000000000..0bf35014c8
--- /dev/null
+++ b/core/src/main/scala/spark/deploy/master/WorkerState.scala
@@ -0,0 +1,7 @@
+package spark.deploy.master
+
+/** Enumeration of the states a worker can be in: ALIVE, DEAD or DECOMMISSIONED. */
+private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") {
+  type WorkerState = Value
+
+  // Declaration order matters: values are numbered, and named, in this order
+  val ALIVE = Value
+  val DEAD = Value
+  val DECOMMISSIONED = Value
+}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index beceb55ecd..0d1fe2a6b4 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -106,11 +106,6 @@ private[spark] class ExecutorRunner(
throw new IOException("Failed to create directory " + executorDir)
}
- // Download the files it depends on into it (disabled for now)
- //for (url <- jobDesc.fileUrls) {
- // fetchFile(url, executorDir)
- //}
-
// Launch the process
val command = buildCommandSeq()
val builder = new ProcessBuilder(command: _*).directory(executorDir)
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 7c9e588ea2..19bf2be118 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -119,10 +119,10 @@ private[spark] class Worker(
logError("Worker registration failed: " + message)
System.exit(1)
- case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_) =>
+ case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name))
val manager = new ExecutorRunner(
- jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, sparkHome, workDir)
+ jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
executors(jobId + "/" + execId) = manager
manager.start()
coresUsed += cores_
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
index 340920025b..37524a7c82 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -104,9 +104,25 @@ private[spark] class WorkerArguments(args: Array[String]) {
}
def inferDefaultMemory(): Int = {
- val bean = ManagementFactory.getOperatingSystemMXBean
- .asInstanceOf[com.sun.management.OperatingSystemMXBean]
- val totalMb = (bean.getTotalPhysicalMemorySize / 1024 / 1024).toInt
+ val ibmVendor = System.getProperty("java.vendor").contains("IBM")
+ var totalMb = 0
+ try {
+ val bean = ManagementFactory.getOperatingSystemMXBean()
+ if (ibmVendor) {
+ val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ } else {
+ val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ }
+ } catch {
+ case e: Exception => {
+ totalMb = 2*1024
+ System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
+ }
+ }
// Leave out 1 GB for the operating system, but don't return a negative memory size
math.max(totalMb - 1024, 512)
}
diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
index d06f4884ee..f9489d99fc 100644
--- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
+++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -7,7 +7,11 @@ import akka.util.Timeout
import akka.util.duration._
import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
+import cc.spray.http.MediaTypes
+import cc.spray.typeconversion.SprayJsonSupport._
+
import spark.deploy.{WorkerState, RequestWorkerState}
+import spark.deploy.JsonProtocol._
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
@@ -18,13 +22,20 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
val handler = {
get {
- path("") {
- completeWith{
+ (path("") & parameters('format ?)) {
+ case Some(js) if js.equalsIgnoreCase("json") => {
val future = worker ? RequestWorkerState
- future.map { workerState =>
- spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
+ respondWithMediaType(MediaTypes.`application/json`) { ctx =>
+ ctx.complete(future.mapTo[WorkerState])
}
}
+ case _ =>
+ completeWith{
+ val future = worker ? RequestWorkerState
+ future.map { workerState =>
+ spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
+ }
+ }
} ~
path("log") {
parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) =>
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 2552958d27..28d9d40d43 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -159,22 +159,24 @@ private[spark] class Executor extends Logging {
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 915f71ba9f..a29bf974d2 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend(
with ExecutorBackend
with Logging {
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
var master: ActorRef = null
override def preStart() {
diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index 80262ab7b4..c193bf7c8d 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -135,8 +135,11 @@ extends Connection(SocketChannel.open, selector_) {
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages += message // this is probably incorrect, it wont work as fifo
- if (!message.started) logDebug("Starting to send [" + message + "]")
- message.started = true
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 642fa4b525..2ecd14f536 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -43,18 +43,17 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
val selector = SelectorProvider.provider.openSelector()
- val handleMessageExecutor = Executors.newFixedThreadPool(4)
+ val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
- val connectionRequests = new SynchronizedQueue[SendingConnection]
+ val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
- implicit val futureExecContext = ExecutionContext.fromExecutor(
- Executors.newCachedThreadPool(DaemonThreadFactory))
-
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
@@ -79,10 +78,10 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def run() {
try {
while(!selectorThread.isInterrupted) {
- while(!connectionRequests.isEmpty) {
- val sendingConnection = connectionRequests.dequeue
+ for( (connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
+ connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while(!sendMessageRequests.isEmpty) {
@@ -300,8 +299,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
- val newConnection = new SendingConnection(inetSocketAddress, selector)
- connectionRequests += newConnection
+ val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
@@ -473,6 +471,7 @@ private[spark] object ConnectionManager {
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
}
diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
index 47ceaf3c07..533e4610f3 100644
--- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala
@@ -13,8 +13,14 @@ import akka.util.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
+ //<mesos cluster> - the master URL
+ //<slaves file> - a list of slaves to run connectionTest on
+ //[num of tasks] - the number of parallel tasks to be initiated default is number of slave hosts
+ //[size of msg in MB (integer)] - the size of messages to be sent in each task, default is 10
+ //[count] - how many times to run, default is 3
+ //[await time in seconds] : await time (in seconds), default is 600
if (args.length < 2) {
- println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")
+ println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] [size of msg in MB (integer)] [count] [await time in seconds)] ")
System.exit(1)
}
@@ -29,16 +35,19 @@ private[spark] object ConnectionManagerTest extends Logging{
/*println("Slaves")*/
/*slaves.foreach(println)*/
-
- val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map(
+ val tasknum = if (args.length > 2) args(2).toInt else slaves.length
+ val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
+ val count = if (args.length > 4) args(4).toInt else 3
+ val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
+ println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
+ val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
i => SparkEnv.get.connectionManager.id).collect()
println("\nSlave ConnectionManagerIds")
slaveConnManagerIds.foreach(println)
println
- val count = 10
(0 until count).foreach(i => {
- val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => {
+ val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
val connManager = SparkEnv.get.connectionManager
val thisConnManagerId = connManager.id
connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
@@ -46,7 +55,6 @@ private[spark] object ConnectionManagerTest extends Logging{
None
})
- val size = 100 * 1024 * 1024
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@@ -56,13 +64,13 @@ private[spark] object ConnectionManagerTest extends Logging{
logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
})
- val results = futures.map(f => Await.result(f, 1.second))
+ val results = futures.map(f => Await.result(f, awaitTime))
val finishTime = System.currentTimeMillis
Thread.sleep(5000)
val mb = size * results.size / 1024.0 / 1024.0
val ms = finishTime - startTime
- val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
logInfo(resultStr)
resultStr
}).collect()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index f98528a183..2c022f88e0 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -1,9 +1,7 @@
package spark.rdd
import scala.collection.mutable.HashMap
-
-import spark.{Dependency, RDD, SparkContext, SparkEnv, Split, TaskContext}
-
+import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext}
private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
val index = idx
@@ -11,22 +9,20 @@ private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split
private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
- extends RDD[T](sc) {
+ extends RDD[T](sc, Nil) {
- @transient
- val splits_ = (0 until blockIds.size).map(i => {
+ @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => {
new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
}).toArray
- @transient
- lazy val locations_ = {
+ @transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
@@ -38,9 +34,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def preferredLocations(split: Split) =
+ override def getPreferredLocations(split: Split) =
locations_(split.asInstanceOf[BlockRDDSplit].blockId)
- override val dependencies: List[Dependency[_]] = Nil
+ // Drops the cached split array so this RDD stops referencing its BlockRDDSplit objects.
+ // NOTE(review): presumably invoked by the checkpointing code once the RDD's data has been
+ // saved -- confirm against the RDD base class.
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala
index 4a7e5f3d06..453d410ad4 100644
--- a/core/src/main/scala/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala
@@ -1,37 +1,53 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
private[spark]
-class CartesianSplit(idx: Int, val s1: Split, val s2: Split) extends Split with Serializable {
+class CartesianSplit(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Split {
+ var s1 = rdd1.splits(s1Index)
+ var s2 = rdd2.splits(s2Index)
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.splits(s1Index)
+ s2 = rdd2.splits(s2Index)
+ oos.defaultWriteObject()
+ }
}
private[spark]
class CartesianRDD[T: ClassManifest, U:ClassManifest](
sc: SparkContext,
- rdd1: RDD[T],
- rdd2: RDD[U])
- extends RDD[Pair[T, U]](sc)
+ var rdd1 : RDD[T],
+ var rdd2 : RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
with Serializable {
val numSplitsInRdd2 = rdd2.splits.size
- @transient
- val splits_ = {
+ @transient var splits_ = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
val idx = s1.index * numSplitsInRdd2 + s2.index
- array(idx) = new CartesianSplit(idx, s1, s2)
+ array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
}
@@ -42,7 +58,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
- override val dependencies = List(
+ var deps_ = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -50,4 +66,13 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2)
}
)
+
+ override def getDependencies = deps_
+
+ // Releases everything that ties this RDD to its two parents: the dependency list is
+ // emptied, and the cached splits and both parent RDD references are nulled out.
+ // NOTE(review): presumably invoked by the checkpointing code once the RDD's data has been
+ // saved -- confirm against the RDD base class.
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..6f00f6ac73
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,128 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+/**
+ * A split of a CheckpointRDD, identified by its position and the checkpoint file backing it.
+ *
+ * @param idx position of this split within the RDD; exposed as `index`
+ * @param splitFile path (as a string) of the file holding this split's data
+ */
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+ override val index: Int = idx
+}
+
+/**
+ * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
+ *
+ * Each "part-NNNNN" file found directly under `checkpointPath` becomes one split, and
+ * computing a split deserializes the contents of its file.
+ *
+ * @param sc the SparkContext this RDD belongs to
+ * @param checkpointPath directory containing the checkpoint's part files
+ */
+private[spark]
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+  extends RDD[T](sc, Nil) {
+
+  @transient val path = new Path(checkpointPath)
+  @transient val fs = path.getFileSystem(new Configuration())
+
+  // One split per part file, ordered by file name. The match is made on the file name
+  // (not the full path) and anchored at its start, so that leftover temporary files
+  // (".part-NNNNN-attempt-N", as written by CheckpointRDD.writeToFile) and directories
+  // whose name merely contains "part-" are not mistaken for splits.
+  @transient val splits_ : Array[Split] = {
+    val splitFiles = fs.listStatus(path)
+      .map(_.getPath)
+      .filter(_.getName.startsWith("part-"))
+      .map(_.toString)
+      .sorted
+    splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+  }
+
+  // Mark this RDD as already checkpointed at `checkpointPath`.
+  checkpointData = Some(new RDDCheckpointData[T](this))
+  checkpointData.get.cpFile = Some(checkpointPath)
+
+  override def getSplits = splits_
+
+  /**
+   * Returns the hosts storing the first block of the given split's file, excluding
+   * "localhost".
+   */
+  override def getPreferredLocations(split: Split): Seq[String] = {
+    // Query the split's own file; the status of the checkpoint directory says nothing
+    // about where the split's data is stored.
+    val splitFile = new Path(split.asInstanceOf[CheckpointRDDSplit].splitFile)
+    val status = fs.getFileStatus(splitFile)
+    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+  }
+
+  /** Reads the split's file back as an iterator of deserialized elements. */
+  override def compute(split: Split, context: TaskContext): Iterator[T] = {
+    CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
+  }
+
+  override def checkpoint() {
+    // Do nothing. A CheckpointRDD already reads from checkpoint files, so it should not be
+    // checkpointed again.
+  }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+  /**
+   * Builds the checkpoint file name for a split: "part-" followed by the split id,
+   * zero-padded to at least five digits and written without grouping separators.
+   */
+  def splitIdToFileName(splitId: Int): String = {
+    val formatter = NumberFormat.getInstance()
+    formatter.setGroupingUsed(false)
+    formatter.setMinimumIntegerDigits(5)
+    val paddedId = formatter.format(splitId)
+    "part-" + paddedId
+  }
+
+  /**
+   * Serializes the elements of one split into its checkpoint file under `path`.
+   *
+   * The data is first written to a hidden, attempt-specific temporary file and then renamed
+   * to the final "part-NNNNN" name. If the rename fails, any earlier output at the final
+   * name is deleted and the rename is retried once.
+   *
+   * @param path checkpoint directory to write into
+   * @param blockSize block size passed to the file system when non-negative (mainly for
+   *                  testing); a negative value uses the file system's default
+   * @param context task context supplying the split id and attempt id
+   * @param iterator the split's elements
+   * @throws IOException if the temporary file already exists, or if the output cannot be
+   *                     moved to its final name
+   */
+  def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+    val outputDir = new Path(path)
+    val fs = outputDir.getFileSystem(new Configuration())
+
+    val finalOutputName = splitIdToFileName(context.splitId)
+    val finalOutputPath = new Path(outputDir, finalOutputName)
+    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+    if (fs.exists(tempOutputPath)) {
+      throw new IOException("Checkpoint failed: temporary path " +
+        tempOutputPath + " already exists")
+    }
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purpose
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val serializeStream = serializer.serializeStream(fileOutputStream)
+    // Close the stream even when writing fails, so the underlying file handle is not leaked.
+    try {
+      serializeStream.writeAll(iterator)
+    } finally {
+      serializeStream.close()
+    }
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      // The rename failed: remove whatever is at the final name and try once more.
+      if (!fs.delete(finalOutputPath, true)) {
+        throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+          + context.attemptId)
+      }
+      if (!fs.rename(tempOutputPath, finalOutputPath)) {
+        throw new IOException("Checkpoint failed: failed to save output of task: "
+          + context.attemptId)
+      }
+    }
+  }
+
+  /**
+   * Opens the checkpoint file at `path` and returns its contents as an iterator of
+   * deserialized elements. The underlying stream is closed by a callback registered on
+   * `context`, which runs when the task completes.
+   */
+  def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
+    val file = new Path(path)
+    val fileSystem = file.getFileSystem(new Configuration())
+    val readBufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val in = fileSystem.open(file, readBufferSize)
+    val stream = SparkEnv.get.serializer.newInstance().deserializeStream(in)
+
+    // The iterator is consumed lazily, so closing is deferred to task completion.
+    context.addOnCompleteCallback(() => stream.close())
+
+    stream.asIterator.asInstanceOf[Iterator[T]]
+  }
+
+  /**
+   * Tests whether CheckpointRDD generates the expected number of splits even when each
+   * split file spans multiple blocks. This needs to be run on a cluster (mesos or
+   * standalone) using HDFS.
+   *
+   * @param args `<cluster>` (the master URL) and `<hdfs path>` (a directory under which a
+   *             "temp" checkpoint directory is created and deleted again)
+   */
+  def main(args: Array[String]) {
+    if (args.length != 2) {
+      System.err.println("Usage: CheckpointRDD <cluster> <hdfs path>")
+      System.exit(1)
+    }
+    val Array(cluster, hdfsPath) = args
+    val sc = new SparkContext(cluster, "CheckpointRDD Test")
+    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+    val path = new Path(hdfsPath, "temp")
+    val fs = path.getFileSystem(new Configuration())
+    // A 1024-byte block size forces every split file to span several blocks.
+    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
+    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+    assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+    // Recursive delete; replaces the deprecated single-argument FileSystem.delete.
+    fs.delete(path, true)
+  }
+}
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index de0d9fad88..8fafd27bb6 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,14 +1,30 @@
package spark.rdd
+import java.io.{ObjectOutputStream, IOException}
+import java.util.{HashMap => JHashMap}
+import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
private[spark] sealed trait CoGroupSplitDep extends Serializable
-private[spark] case class NarrowCoGroupSplitDep(rdd: RDD[_], split: Split) extends CoGroupSplitDep
+
+/**
+ * Describes the dependency of a co-grouped split on a single split of a parent RDD that is
+ * read directly from the parent (as opposed to ShuffleCoGroupSplitDep).
+ *
+ * @param rdd the parent RDD
+ * @param splitIndex index of the parent split within `rdd.splits`
+ * @param split the parent split itself; refreshed from `rdd` whenever this object is serialized
+ */
+private[spark] case class NarrowCoGroupSplitDep(
+ rdd: RDD[_],
+ splitIndex: Int,
+ var split: Split
+ ) extends CoGroupSplitDep {
+
+ // Java serialization hook: re-reads the parent split before the default field write.
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
+}
+
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
@@ -24,30 +40,29 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
-class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
- extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging {
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging {
val aggr = new CoGroupAggregator
- @transient
- override val dependencies = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
for ((rdd, index) <- rdds.zipWithIndex) {
- val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
- if (mapSideCombinedRDD.partitioner == Some(part)) {
- logInfo("Adding one-to-one dependency with " + mapSideCombinedRDD)
- deps += new OneToOneDependency(mapSideCombinedRDD)
+ if (rdd.partitioner == Some(part)) {
+ logInfo("Adding one-to-one dependency with " + rdd)
+ deps += new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
+ val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
}
}
deps.toList
}
- @transient
- val splits_ : Array[Split] = {
- val firstRdd = rdds.head
+ override def getDependencies = deps_
+
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](part.numPartitions)
for (i <- 0 until array.size) {
array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) =>
@@ -55,28 +70,33 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
case s: ShuffleDependency[_, _] =>
new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep
case _ =>
- new NarrowCoGroupSplitDep(r, r.splits(i)): CoGroupSplitDep
+ new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep
}
}.toList)
}
array
}
- override def splits = splits_
-
+ override def getSplits = splits_
+
override val partitioner = Some(part)
- override def preferredLocations(s: Split) = Nil
-
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size
- val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
+ map.put(k, seq)
+ seq
+ }
}
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
- case NarrowCoGroupSplitDep(rdd, itsSplit) => {
+ case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => {
// Read them from the parent
for ((k, v) <- rdd.iterator(itsSplit, context)) {
getSeq(k.asInstanceOf[K])(depNum) += v
@@ -93,6 +113,12 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
}
}
- map.iterator
+ JavaConversions.mapAsScalaMap(map).iterator
+ }
+
+  // Releases everything that ties this RDD to its parents: the dependency list is emptied,
+  // and the cached splits and the parent RDD sequence are nulled out.
+  // The dependency list is reset to Nil rather than null, matching CartesianRDD and
+  // CoalescedRDD, so that getDependencies never hands a null list to its callers.
+  // NOTE(review): presumably invoked by the checkpointing code once the RDD's data has been
+  // saved -- confirm against the RDD base class.
+  override def clearDependencies() {
+    deps_ = Nil
+    splits_ = null
+    rdds = null
}
}
diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
index 1affe0e0ef..167755bbba 100644
--- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala
@@ -1,9 +1,22 @@
package spark.rdd
-import spark.{NarrowDependency, RDD, Split, TaskContext}
+import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] case class CoalescedRDDSplit(
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int]
+ ) extends Split {
+ var parents: Seq[Split] = parentsIndices.map(rdd.splits(_))
-private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) extends Split
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ parents = parentsIndices.map(rdd.splits(_))
+ oos.defaultWriteObject()
+ }
+}
/**
* Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
@@ -13,34 +26,44 @@ private class CoalescedRDDSplit(val index: Int, val parents: Array[Split]) exten
* This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
-class CoalescedRDD[T: ClassManifest](prev: RDD[T], maxPartitions: Int)
- extends RDD[T](prev.context) {
+class CoalescedRDD[T: ClassManifest](
+ var prev: RDD[T],
+ maxPartitions: Int)
+ extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- @transient val splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
- prevSplits.zipWithIndex.map{ case (s, idx) => new CoalescedRDDSplit(idx, Array(s)) }
+ prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
- new CoalescedRDDSplit(i, prevSplits.slice(rangeStart, rangeEnd))
+ new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
- split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap {
- parentSplit => prev.iterator(parentSplit, context)
+ split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
+ firstParent[T].iterator(parentSplit, context)
}
}
- val dependencies = List(
+ var deps_ : List[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
- splits(id).asInstanceOf[CoalescedRDDSplit].parents.map(_.index)
+ splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
+
+ override def getDependencies() = deps_
+
+ // Releases everything that ties this RDD to its parent: the dependency list is emptied,
+ // and the cached splits and the `prev` reference are nulled out.
+ // NOTE(review): presumably invoked by the checkpointing code once the RDD's data has been
+ // saved -- confirm against the RDD base class.
+ override def clearDependencies() {
+ deps_ = Nil
+ splits_ = null
+ prev = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index b148da28de..6dbe235bd9 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -2,10 +2,15 @@ package spark.rdd
import spark.{OneToOneDependency, RDD, Split, TaskContext}
+private[spark] class FilteredRDD[T: ClassManifest](
+ prev: RDD[T],
+ f: T => Boolean)
+ extends RDD[T](prev) {
-private[spark]
-class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
-} \ No newline at end of file
+ override def getSplits = firstParent[T].splits
+
+ override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).filter(f)
+}
diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
index 785662b2da..1b604c66e2 100644
--- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala
@@ -1,16 +1,16 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
private[spark]
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
- prev.iterator(split, context).flatMap(f)
+ firstParent[T].iterator(split, context).flatMap(f)
}
diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala
index fac8ffb4cb..051bffed19 100644
--- a/core/src/main/scala/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala
@@ -1,12 +1,12 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+ extends RDD[Array[T]](prev) {
+
+ override def getSplits = firstParent[T].splits
-private[spark]
-class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split, context: TaskContext) =
- Array(prev.iterator(split, context).toArray).iterator
-} \ No newline at end of file
+ Array(firstParent[T].iterator(split, context).toArray).iterator
+}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index ab163f569b..f547f53812 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -22,9 +22,8 @@ import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskCo
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
- extends Split
- with Serializable {
-
+ extends Split {
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -43,7 +42,7 @@ class HadoopRDD[K, V](
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
- extends RDD[(K, V)](sc) {
+ extends RDD[(K, V)](sc, Nil) {
// A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
val confBroadcast = sc.broadcast(new SerializableWritable(conf))
@@ -64,7 +63,7 @@ class HadoopRDD[K, V](
.asInstanceOf[InputFormat[K, V]]
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopSplit]
@@ -110,11 +109,13 @@ class HadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
// TODO: Filtering out "localhost" in case of file:// URLs
val hadoopSplit = split.asInstanceOf[HadoopSplit]
hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
}
- override val dependencies: List[Dependency[_]] = Nil
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
}
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
index c764505345..073f7d7d2a 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala
@@ -1,6 +1,6 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
@@ -8,11 +8,13 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false)
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
+ override val partitioner =
+ if (preservesPartitioning) firstParent[T].partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = f(prev.iterator(split, context))
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ f(firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
index 3d9888bd34..2ddc3d01b6 100644
--- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
+++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala
@@ -1,6 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
+
/**
* A variant of the MapPartitionsRDD that passes the split index into the
@@ -11,12 +12,13 @@ private[spark]
class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean)
- extends RDD[U](prev.context) {
+ preservesPartitioning: Boolean
+ ) extends RDD[U](prev) {
+
+ override def getSplits = firstParent[T].splits
override val partitioner = if (preservesPartitioning) prev.partitioner else None
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
+
override def compute(split: Split, context: TaskContext) =
- f(split.index, prev.iterator(split, context))
+ f(split.index, firstParent[T].iterator(split, context))
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala
index 70fa8f4497..c6ceb272cd 100644
--- a/core/src/main/scala/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/spark/rdd/MappedRDD.scala
@@ -1,14 +1,15 @@
package spark.rdd
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
+import spark.{RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
- extends RDD[U](prev.context) {
+ extends RDD[U](prev) {
- override def splits = prev.splits
- override val dependencies = List(new OneToOneDependency(prev))
- override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).map(f)
+ override def getSplits = firstParent[T].splits
+
+ override def compute(split: Split, context: TaskContext) =
+ firstParent[T].iterator(split, context).map(f)
} \ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
index 197ed5ea17..c3b155fcbd 100644
--- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala
@@ -20,11 +20,12 @@ class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit
}
class NewHadoopRDD[K, V](
- sc: SparkContext,
+ sc : SparkContext,
inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K], valueClass: Class[V],
+ keyClass: Class[K],
+ valueClass: Class[V],
@transient conf: Configuration)
- extends RDD[(K, V)](sc)
+ extends RDD[(K, V)](sc, Nil)
with HadoopMapReduceUtil {
// A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
@@ -36,11 +37,9 @@ class NewHadoopRDD[K, V](
formatter.format(new Date())
}
- @transient
- private val jobId = new JobID(jobtrackerId, id)
+ @transient private val jobId = new JobID(jobtrackerId, id)
- @transient
- private val splits_ : Array[Split] = {
+ @transient private val splits_ : Array[Split] = {
val inputFormat = inputFormatClass.newInstance
val jobContext = newJobContext(conf, jobId)
val rawSplits = inputFormat.getSplits(jobContext).toArray
@@ -51,7 +50,7 @@ class NewHadoopRDD[K, V](
result
}
- override def splits = splits_
+ override def getSplits = splits_
override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopSplit]
@@ -86,10 +85,8 @@ class NewHadoopRDD[K, V](
}
}
- override def preferredLocations(split: Split) = {
+ override def getPreferredLocations(split: Split) = {
val theSplit = split.asInstanceOf[NewHadoopSplit]
theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
}
-
- override val dependencies: List[Dependency[_]] = Nil
}
diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala
index 336e193217..6631f83510 100644
--- a/core/src/main/scala/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/spark/rdd/PipedRDD.scala
@@ -8,7 +8,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
-import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
+import spark.{RDD, SparkEnv, Split, TaskContext}
/**
@@ -16,18 +16,18 @@ import spark.{OneToOneDependency, RDD, SparkEnv, Split, TaskContext}
* (printing them one per line) and returns the output as a collection of strings.
*/
class PipedRDD[T: ClassManifest](
- parent: RDD[T], command: Seq[String], envVars: Map[String, String])
- extends RDD[String](parent.context) {
+ prev: RDD[T],
+ command: Seq[String],
+ envVars: Map[String, String])
+ extends RDD[String](prev) {
- def this(parent: RDD[T], command: Seq[String]) = this(parent, command, Map())
+ def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String) = this(parent, PipedRDD.tokenize(command))
+ def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
- override def splits = parent.splits
-
- override val dependencies = List(new OneToOneDependency(parent))
+ override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext): Iterator[String] = {
val pb = new ProcessBuilder(command)
@@ -52,7 +52,7 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
- for (elem <- parent.iterator(split, context)) {
+ for (elem <- firstParent[T].iterator(split, context)) {
out.println(elem)
}
out.close()
diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala
index 6e4797aabb..e24ad23b21 100644
--- a/core/src/main/scala/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/spark/rdd/SampledRDD.scala
@@ -1,11 +1,11 @@
package spark.rdd
import java.util.Random
+
import cern.jet.random.Poisson
import cern.jet.random.engine.DRand
-import spark.{OneToOneDependency, RDD, Split, TaskContext}
-
+import spark.{RDD, Split, TaskContext}
private[spark]
class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable {
@@ -14,23 +14,20 @@ class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Seriali
class SampledRDD[T: ClassManifest](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
- extends RDD[T](prev.context) {
+ extends RDD[T](prev) {
- @transient
- val splits_ = {
+ @transient var splits_ : Array[Split] = {
val rg = new Random(seed)
- prev.splits.map(x => new SampledRDDSplit(x, rg.nextInt))
+ firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt))
}
- override def splits = splits_.asInstanceOf[Array[Split]]
-
- override val dependencies = List(new OneToOneDependency(prev))
+ override def getSplits = splits_
- override def preferredLocations(split: Split) =
- prev.preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
+ override def getPreferredLocations(split: Split) =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev)
override def compute(splitIn: Split, context: TaskContext) = {
val split = splitIn.asInstanceOf[SampledRDDSplit]
@@ -38,7 +35,7 @@ class SampledRDD[T: ClassManifest](
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
val poisson = new Poisson(frac, new DRand(split.seed))
- prev.iterator(split.prev, context).flatMap { element =>
+ firstParent[T].iterator(split.prev, context).flatMap { element =>
val count = poisson.nextInt()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
@@ -48,7 +45,11 @@ class SampledRDD[T: ClassManifest](
}
} else { // Sampling without replacement
val rand = new Random(split.seed)
- prev.iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
}
}
+
+ override def clearDependencies() {
+ splits_ = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index f832633646..28ff19876d 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -1,7 +1,7 @@
package spark.rdd
-import spark.{OneToOneDependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
-
+import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext}
+import spark.SparkContext._
private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
@@ -10,28 +10,28 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
- * @param parent the parent RDD.
+ * @param prev the parent RDD.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
*/
class ShuffledRDD[K, V](
- @transient parent: RDD[(K, V)],
- part: Partitioner) extends RDD[(K, V)](parent.context) {
+ prev: RDD[(K, V)],
+ part: Partitioner)
+ extends RDD[(K, V)](prev.context, List(new ShuffleDependency(prev, part))) {
override val partitioner = Some(part)
- @transient
- val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
-
- override def splits = splits_
-
- override def preferredLocations(split: Split) = Nil
+ @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
- val dep = new ShuffleDependency(parent, part)
- override val dependencies = List(dep)
+ override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
- SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index)
+ val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
+ SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
+ }
+
+ override def clearDependencies() {
+ splits_ = null
}
}
diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala
index a08473f7be..82f0a44ecd 100644
--- a/core/src/main/scala/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/spark/rdd/UnionRDD.scala
@@ -1,43 +1,46 @@
package spark.rdd
import scala.collection.mutable.ArrayBuffer
-
import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Split {
-private[spark] class UnionSplit[T: ClassManifest](
- idx: Int,
- rdd: RDD[T],
- split: Split)
- extends Split
- with Serializable {
+ var split: Split = rdd.splits(splitIndex)
def iterator(context: TaskContext) = rdd.iterator(split, context)
+
def preferredLocations() = rdd.preferredLocations(split)
+
override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.splits(splitIndex)
+ oos.defaultWriteObject()
+ }
}
class UnionRDD[T: ClassManifest](
sc: SparkContext,
- @transient rdds: Seq[RDD[T]])
- extends RDD[T](sc)
- with Serializable {
+ @transient var rdds: Seq[RDD[T]])
+ extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
- @transient
- val splits_ : Array[Split] = {
+ @transient var splits_ : Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
- array(pos) = new UnionSplit(pos, rdd, split)
+ array(pos) = new UnionSplit(pos, rdd, split.index)
pos += 1
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- @transient
- override val dependencies = {
+ @transient var deps_ = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
@@ -47,9 +50,17 @@ class UnionRDD[T: ClassManifest](
deps.toList
}
+ override def getDependencies = deps_
+
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
- override def preferredLocations(s: Split): Seq[String] =
+ override def getPreferredLocations(s: Split): Seq[String] =
s.asInstanceOf[UnionSplit[T]].preferredLocations()
+
+ override def clearDependencies() {
+ deps_ = null
+ splits_ = null
+ rdds = null
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
index 92d667ff1e..d950b06c85 100644
--- a/core/src/main/scala/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -1,53 +1,65 @@
package spark.rdd
import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
idx: Int,
- rdd1: RDD[T],
- rdd2: RDD[U],
- split1: Split,
- split2: Split)
- extends Split
- with Serializable {
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U]
+ ) extends Split {
- def iterator(context: TaskContext): Iterator[(T, U)] =
- rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ var split1 = rdd1.splits(idx)
+ var split2 = rdd1.splits(idx)
+ override val index: Int = idx
- def preferredLocations(): Seq[String] =
- rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ def splits = (split1, split2)
- override val index: Int = idx
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split1 = rdd1.splits(idx)
+ split2 = rdd2.splits(idx)
+ oos.defaultWriteObject()
+ }
}
class ZippedRDD[T: ClassManifest, U: ClassManifest](
sc: SparkContext,
- @transient rdd1: RDD[T],
- @transient rdd2: RDD[U])
- extends RDD[(T, U)](sc)
+ var rdd1: RDD[T],
+ var rdd2: RDD[U])
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
- @transient
- val splits_ : Array[Split] = {
+ // TODO: FIX THIS.
+
+ @transient var splits_ : Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
val array = new Array[Split](rdd1.splits.size)
for (i <- 0 until rdd1.splits.size) {
- array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+ array(i) = new ZippedSplit(i, rdd1, rdd2)
}
array
}
- override def splits = splits_
+ override def getSplits = splits_
- @transient
- override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+ override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
+ }
- override def compute(s: Split, context: TaskContext): Iterator[(T, U)] =
- s.asInstanceOf[ZippedSplit[T, U]].iterator(context)
+ override def getPreferredLocations(s: Split): Seq[String] = {
+ val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
+ rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+ }
- override def preferredLocations(s: Split): Seq[String] =
- s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+ override def clearDependencies() {
+ splits_ = null
+ rdd1 = null
+ rdd2 = null
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 29757b1178..b320be8863 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -14,6 +14,7 @@ import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
+import util.{MetadataCleaner, TimeStampedHashMap}
/**
* A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -61,28 +62,35 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val nextStageId = new AtomicInteger(0)
- val idToStage = new HashMap[Int, Stage]
+ val idToStage = new TimeStampedHashMap[Int, Stage]
- val shuffleToMapStage = new HashMap[Int, Stage]
+ val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
var cacheLocs = new HashMap[Int, Array[List[String]]]
val env = SparkEnv.get
- val cacheTracker = env.cacheTracker
val mapOutputTracker = env.mapOutputTracker
+ val blockManagerMaster = env.blockManager.master
- val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
- // that's not going to be a realistic assumption in general
+ // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
+ // sent with every task. When we detect a node failing, we note the current generation number
+ // and failed host, increment it for new tasks, and use this to ignore stray ShuffleMapTask
+ // results.
+ // TODO: Garbage collect information about failure generations when we know there are no more
+ // stray messages to detect.
+ val failedGeneration = new HashMap[String, Long]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
- val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
+ val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+
// Start a thread to run the DAGScheduler event loop
new Thread("DAGScheduler") {
setDaemon(true)
@@ -92,11 +100,17 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}.start()
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ if (!cacheLocs.contains(rdd.id)) {
+ val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
+ locations => locations.map(_.ip).toList
+ }.toArray
+ }
cacheLocs(rdd.id)
}
- def updateCacheLocs() {
- cacheLocs = cacheTracker.getLocationsSnapshot()
+ def clearCacheLocs() {
+ cacheLocs.clear
}
/**
@@ -123,7 +137,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- cacheTracker.registerRDD(rdd.id, rdd.splits.size)
if (shuffleDep != None) {
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size)
}
@@ -145,8 +158,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visited += r
// Kind of ugly: need to register RDDs with the cache here since
// we can't do it in its constructor because # of splits is unknown
- logInfo("Registering parent RDD " + r.id + " (" + r.origin + ")")
- cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
@@ -247,7 +258,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val runId = nextRunId.getAndIncrement()
val finalStage = newStage(finalRDD, None, runId)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
- updateCacheLocs()
+ clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
" output partitions")
logInfo("Final stage: " + finalStage + " (" + finalStage.origin + ")")
@@ -290,7 +301,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
logInfo("Resubmitting failed stages")
- updateCacheLocs()
+ clearCacheLocs()
val failed2 = failed.toArray
failed.clear()
for (stage <- failed2.sortBy(_.priority)) {
@@ -426,7 +437,9 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val status = event.result.asInstanceOf[MapStatus]
val host = status.address.ip
logInfo("ShuffleMapTask finished with host " + host)
- if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+ if (failedGeneration.contains(host) && smt.generation <= failedGeneration(host)) {
+ logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + host)
+ } else {
stage.addOutputLoc(smt.partition, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
@@ -436,11 +449,18 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
+ // We supply true to increment the generation number here in case this is a
+ // recomputation of the map outputs. In that case, some nodes may have cached
+ // locations with holes (from when we detected the error) and will need the
+ // generation incremented to refetch them.
+ // TODO: Only increment the generation number if this is not the first time
+ // we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
- stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+ true)
}
- updateCacheLocs()
+ clearCacheLocs()
if (stage.outputLocs.count(_ == Nil) != 0) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
@@ -492,7 +512,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the host as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleHostLost(bmAddress.ip)
+ handleHostLost(bmAddress.ip, Some(task.generation))
}
case other =>
@@ -504,11 +524,15 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
/**
* Responds to a host being lost. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use hostLost() to post a host lost event from outside.
+ *
+ * Optionally the generation during which the failure was caught can be passed to avoid allowing
+ * stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- def handleHostLost(host: String) {
- if (!deadHosts.contains(host)) {
- logInfo("Host lost: " + host)
- deadHosts += host
+ def handleHostLost(host: String, maybeGeneration: Option[Long] = None) {
+ val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
+ if (!failedGeneration.contains(host) || failedGeneration(host) < currentGeneration) {
+ failedGeneration(host) = currentGeneration
+ logInfo("Host lost: " + host + " (generation " + currentGeneration + ")")
env.blockManager.master.notifyADeadHost(host)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
@@ -516,8 +540,13 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
- cacheTracker.cacheLost(host)
- updateCacheLocs()
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementGeneration()
+ }
+ clearCacheLocs()
+ } else {
+ logDebug("Additional host lost message for " + host +
+ "(generation " + currentGeneration + ")")
}
}
@@ -594,8 +623,23 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
return Nil
}
+ def cleanup(cleanupTime: Long) {
+ var sizeBefore = idToStage.size
+ idToStage.clearOldValues(cleanupTime)
+ logInfo("idToStage " + sizeBefore + " --> " + idToStage.size)
+
+ sizeBefore = shuffleToMapStage.size
+ shuffleToMapStage.clearOldValues(cleanupTime)
+ logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
+
+ sizeBefore = pendingTasks.size
+ pendingTasks.clearOldValues(cleanupTime)
+ logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
+ }
+
def stop() {
eventQueue.put(StopDAGScheduler)
+ metadataCleaner.cancel()
taskSched.stop()
}
}
diff --git a/core/src/main/scala/spark/scheduler/MapStatus.scala b/core/src/main/scala/spark/scheduler/MapStatus.scala
index 4532d9497f..fae643f3a8 100644
--- a/core/src/main/scala/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/spark/scheduler/MapStatus.scala
@@ -20,7 +20,7 @@ private[spark] class MapStatus(var address: BlockManagerId, var compressedSizes:
}
def readExternal(in: ObjectInput) {
- address = new BlockManagerId(in)
+ address = BlockManagerId(in)
compressedSizes = new Array[Byte](in.readInt())
in.readFully(compressedSizes)
}
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index e492279b4e..8cd4c661eb 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -1,26 +1,112 @@
package spark.scheduler
import spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+ // A simple map between the stage id to the serialized byte array of a task.
+ // Served as a cache for task serialization because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+ synchronized {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ return (rdd, func)
+ }
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
private[spark] class ResultTask[T, U](
stageId: Int,
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- val partition: Int,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ var partition: Int,
@transient locs: Seq[String],
val outputId: Int)
- extends Task[U](stageId) {
+ extends Task[U](stageId) with Externalizable {
- val split = rdd.splits(partition)
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.splits(partition)
+ }
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
- val result = func(context, rdd.iterator(split, context))
- context.executeOnCompleteCallbacks()
- result
+ try {
+ func(context, rdd.iterator(split, context))
+ } finally {
+ context.executeOnCompleteCallbacks()
+ }
}
override def preferredLocations: Seq[String] = locs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeInt(outputId)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partition = in.readInt()
+ val outputId = in.readInt()
+ split = in.readObject().asInstanceOf[Split]
+ }
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index bd1911fce2..19f5328eee 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -14,17 +14,20 @@ import com.ning.compress.lzf.LZFOutputStream
import spark._
import spark.storage._
+import util.{TimeStampedHashMap, MetadataCleaner}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new JHashMap[Int, Array[Byte]]
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
- val old = serializedInfoCache.get(stageId)
+ val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
return old
} else {
@@ -87,13 +90,16 @@ private[spark] class ShuffleMapTask(
}
override def writeExternal(out: ObjectOutput) {
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partition)
- out.writeLong(generation)
- out.writeObject(split)
+ RDDCheckpointData.synchronized {
+ split = rdd.splits(partition)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeLong(generation)
+ out.writeObject(split)
+ }
}
override def readExternal(in: ObjectInput) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 20f6e65020..a639b72795 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -252,19 +252,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def slaveLost(slaveId: String, reason: ExecutorLossReason) {
var failedHost: Option[String] = None
synchronized {
- val host = slaveIdToHost(slaveId)
- if (hostsAlive.contains(host)) {
- logError("Lost an executor on " + host + ": " + reason)
- slaveIdsWithExecutors -= slaveId
- hostsAlive -= host
- activeTaskSetsQueue.foreach(_.hostLost(host))
- failedHost = Some(host)
- } else {
- // We may get multiple slaveLost() calls with different loss reasons. For example, one
- // may be triggered by a dropped connection from the slave while another may be a report
- // of executor termination from Mesos. We produce log messages for both so we eventually
- // report the termination reason.
- logError("Lost an executor on " + host + " (already removed): " + reason)
+ slaveIdToHost.get(slaveId) match {
+ case Some(host) =>
+ if (hostsAlive.contains(host)) {
+ logError("Lost an executor on " + host + ": " + reason)
+ slaveIdsWithExecutors -= slaveId
+ hostsAlive -= host
+ activeTaskSetsQueue.foreach(_.hostLost(host))
+ failedHost = Some(host)
+ } else {
+ // We may get multiple slaveLost() calls with different loss reasons. For example, one
+ // may be triggered by a dropped connection from the slave while another may be a report
+ // of executor termination from Mesos. We produce log messages for both so we eventually
+ // report the termination reason.
+ logError("Lost an executor on " + host + " (already removed): " + reason)
+ }
+ case None =>
+ // We were told about a slave being lost before we could even allocate work to it
+ logError("Lost slave " + slaveId + " (no work assigned yet)")
}
}
if (failedHost != None) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index e2301347e5..4f82cd96dd 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -39,7 +39,8 @@ private[spark] class SparkDeploySchedulerBackend(
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(masterUrl, "{{SLAVEID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
- val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command)
+ val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone"))
+ val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome)
client = new Client(sc.env.actorSystem, master, jobDesc, this)
client.start()
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index cf4aae03a7..a089b71644 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -201,7 +201,11 @@ private[spark] class TaskSetManager(
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
- val prefStr = if (preferred) "preferred" else "non-preferred"
+ val prefStr = if (preferred) {
+ "preferred"
+ } else {
+ "non-preferred, not one of " + task.preferredLocations.mkString(", ")
+ }
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, slaveId, host, prefStr))
// Do various bookkeeping
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index eb20fe41b2..9ff7c02097 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -20,7 +20,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
with Logging {
var attemptId = new AtomicInteger(0)
- var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+ var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
var listener: TaskSchedulerListener = null
@@ -81,7 +81,10 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished task " + idInJob)
- listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+
+ // If the threadpool has not already been shutdown, notify DAGScheduler
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, Success, resultToReturn, accumUpdates)
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
@@ -91,7 +94,8 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
- listener.taskEnded(task, new ExceptionFailure(t), null, null)
+ if (!Thread.currentThread().isInterrupted)
+ listener.taskEnded(task, new ExceptionFailure(t), null, null)
}
}
}
@@ -108,22 +112,24 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentFiles(name) = timestamp
- }
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File("."))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(".", localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
index c45c7df69c..014906b028 100644
--- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -64,13 +64,9 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskIdToSlaveId = new HashMap[Int, String]
val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
@@ -184,7 +180,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Helper function to pull out a resource from a Mesos Resources protobuf */
- def getResource(res: JList[Resource], name: String): Double = {
+ private def getResource(res: JList[Resource], name: String): Double = {
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
@@ -193,7 +189,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Build a Mesos resource protobuf object */
- def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
Resource.newBuilder()
.setName(resourceName)
.setType(Value.Type.SCALAR)
@@ -202,7 +198,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
/** Check whether a Mesos task state represents a finished task */
- def isFinished(state: MesosTaskState) = {
+ private def isFinished(state: MesosTaskState) = {
state == MesosTaskState.TASK_FINISHED ||
state == MesosTaskState.TASK_FAILED ||
state == MesosTaskState.TASK_KILLED ||
diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
index 8c7a1dfbc0..2989e31f5e 100644
--- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -76,13 +76,9 @@ private[spark] class MesosSchedulerBackend(
}
def createExecutorInfo(): ExecutorInfo = {
- val sparkHome = sc.getSparkHome() match {
- case Some(path) =>
- path
- case None =>
- throw new SparkException("Spark home is not set; set it through the spark.home system " +
- "property, the SPARK_HOME environment variable or the SparkContext constructor")
- }
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val environment = Environment.newBuilder()
sc.executorEnvs.foreach { case (key, value) =>
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index df295b1820..19cdaaa984 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -1,59 +1,39 @@
package spark.storage
-import akka.actor.{ActorSystem, Cancellable}
-import akka.dispatch.{Await, Future}
-import akka.util.Duration
-import akka.util.duration._
-
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-
-import java.io.{InputStream, OutputStream, Externalizable, ObjectInput, ObjectOutput}
-import java.nio.{MappedByteBuffer, ByteBuffer}
+import java.io.{InputStream, OutputStream}
+import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
import scala.collection.JavaConversions._
-import spark.{CacheTracker, Logging, SizeEstimator, SparkEnv, SparkException, Utils}
-import spark.network._
-import spark.serializer.Serializer
-import spark.util.ByteBufferInputStream
-import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
-import sun.nio.ch.DirectBuffer
-
-
-private[spark] class BlockManagerId(var ip: String, var port: Int) extends Externalizable {
- def this() = this(null, 0) // For deserialization only
-
- def this(in: ObjectInput) = this(in.readUTF(), in.readInt())
+import akka.actor.{ActorSystem, Cancellable, Props}
+import akka.dispatch.{Await, Future}
+import akka.util.Duration
+import akka.util.duration._
- override def writeExternal(out: ObjectOutput) {
- out.writeUTF(ip)
- out.writeInt(port)
- }
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
- override def readExternal(in: ObjectInput) {
- ip = in.readUTF()
- port = in.readInt()
- }
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
- override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+import spark.{Logging, SizeEstimator, SparkEnv, SparkException, Utils}
+import spark.network._
+import spark.serializer.Serializer
+import spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
- override def hashCode = ip.hashCode * 41 + port
+import sun.nio.ch.DirectBuffer
- override def equals(that: Any) = that match {
- case id: BlockManagerId => port == id.port && ip == id.ip
- case _ => false
- }
-}
private[spark]
case class BlockException(blockId: String, message: String, ex: Exception = null)
extends Exception(message)
private[spark]
-class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
- val serializer: Serializer, maxMemory: Long)
+class BlockManager(
+ actorSystem: ActorSystem,
+ val master: BlockManagerMaster,
+ val serializer: Serializer,
+ maxMemory: Long)
extends Logging {
class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
@@ -79,7 +59,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- private val blockInfo = new ConcurrentHashMap[String, BlockInfo](1000)
+ private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: BlockStore =
@@ -89,10 +69,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
implicit val futureExecContext = connectionManager.futureExecContext
val connectionManagerId = connectionManager.id
- val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port)
-
- // TODO: This will be removed after cacheTracker is removed from the code base.
- var cacheTracker: CacheTracker = null
+ val blockManagerId = BlockManagerId(connectionManagerId.host, connectionManagerId.port)
// Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
// for receiving shuffle outputs)
@@ -110,16 +87,20 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
val host = System.getProperty("spark.hostname", Utils.localHostName())
+ val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
+
@volatile private var shuttingDown = false
private def heartBeat() {
- if (!master.mustHeartBeat(HeartBeat(blockManagerId))) {
+ if (!master.sendHeartBeat(blockManagerId)) {
reregister()
}
}
var heartBeatTask: Cancellable = null
+ val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
initialize()
/**
@@ -134,8 +115,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
* BlockManagerWorker actor.
*/
private def initialize() {
- master.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory))
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
BlockManagerWorker.startBlockManagerWorker(this)
if (!BlockManager.getDisableHeartBeatsForTesting) {
heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) {
@@ -156,8 +136,8 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
private def reportAllBlocks() {
logInfo("Reporting " + blockInfo.size + " blocks to the master.")
- for (blockId <- blockInfo.keys) {
- if (!tryToReportBlockStatus(blockId)) {
+ for ((blockId, info) <- blockInfo) {
+ if (!tryToReportBlockStatus(blockId, info)) {
logError("Failed to report " + blockId + " to master; giving up.")
return
}
@@ -171,26 +151,22 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
def reregister() {
// TODO: We might need to rate limit reregistering.
logInfo("BlockManager reregistering with master")
- master.mustRegisterBlockManager(
- RegisterBlockManager(blockManagerId, maxMemory))
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
reportAllBlocks()
}
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = {
- val info = blockInfo.get(blockId)
- if (info != null) info.level else null
- }
+ def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
*/
- def reportBlockStatus(blockId: String) {
- val needReregister = !tryToReportBlockStatus(blockId)
+ def reportBlockStatus(blockId: String, info: BlockInfo) {
+ val needReregister = !tryToReportBlockStatus(blockId, info)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@@ -200,33 +176,27 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
/**
- * Actually send a BlockUpdate message. Returns the mater's response, which will be true if the
- * block was successfully recorded and false if the slave needs to re-register.
+ * Actually send an UpdateBlockInfo message. Returns the master's response,
+ * which will be true if the block was successfully recorded and false if
+ * the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String): Boolean = {
- val (curLevel, inMemSize, onDiskSize, tellMaster) = blockInfo.get(blockId) match {
- case null =>
- (StorageLevel.NONE, 0L, 0L, false)
- case info =>
- info.synchronized {
- info.level match {
- case null =>
- (StorageLevel.NONE, 0L, 0L, false)
- case level =>
- val inMem = level.useMemory && memoryStore.contains(blockId)
- val onDisk = level.useDisk && diskStore.contains(blockId)
- (
- new StorageLevel(onDisk, inMem, level.deserialized, level.replication),
- if (inMem) memoryStore.getSize(blockId) else 0L,
- if (onDisk) diskStore.getSize(blockId) else 0L,
- info.tellMaster
- )
- }
- }
+ private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
+ val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
+ info.level match {
+ case null =>
+ (StorageLevel.NONE, 0L, 0L, false)
+ case level =>
+ val inMem = level.useMemory && memoryStore.contains(blockId)
+ val onDisk = level.useDisk && diskStore.contains(blockId)
+ val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
+ val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
+ (storageLevel, memSize, diskSize, info.tellMaster)
+ }
}
if (tellMaster) {
- master.mustBlockUpdate(BlockUpdate(blockManagerId, blockId, curLevel, inMemSize, onDiskSize))
+ master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)
} else {
true
}
@@ -238,7 +208,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def getLocations(blockId: String): Seq[String] = {
val startTimeMs = System.currentTimeMillis
- var managers = master.mustGetLocations(GetLocations(blockId))
+ var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
@@ -249,8 +219,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
- val locations = master.mustGetLocationsMultipleBlockIds(
- GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+ val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@@ -272,7 +241,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
info.waitForReady() // In case the block is still being put() by another thread
@@ -357,7 +326,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
info.waitForReady() // In case the block is still being put() by another thread
@@ -413,7 +382,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
logDebug("Getting remote block " + blockId)
// Get locations of block
- val locations = master.mustGetLocations(GetLocations(blockId))
+ val locations = master.getLocations(blockId)
// Get block from remote locations
for (loc <- locations) {
@@ -615,7 +584,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
throw new IllegalArgumentException("Storage level is null or invalid")
}
- val oldBlock = blockInfo.get(blockId)
+ val oldBlock = blockInfo.get(blockId).orNull
if (oldBlock != null) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
oldBlock.waitForReady()
@@ -670,7 +639,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
// and tell the master about it.
myInfo.markReady(size)
if (tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, myInfo)
}
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
@@ -690,10 +659,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
BlockManager.dispose(bytesAfterPut)
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs))
return size
@@ -716,7 +681,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
throw new IllegalArgumentException("Storage level is null or invalid")
}
- if (blockInfo.containsKey(blockId)) {
+ if (blockInfo.contains(blockId)) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return
}
@@ -757,15 +722,10 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
// and tell the master about it.
myInfo.markReady(bytes.limit)
if (tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, myInfo)
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- if (blockId.startsWith("rdd")) {
- notifyCacheTracker(blockId)
- }
-
// If replication had started, then wait for it to finish
if (level.replication > 1) {
if (replicationFuture == null) {
@@ -788,10 +748,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
var cachedPeers: Seq[BlockManagerId] = null
private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
- val tLevel: StorageLevel =
- new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
- cachedPeers = master.mustGetPeers(GetPeers(blockManagerId, level.replication - 1))
+ cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
}
for (peer: BlockManagerId <- cachedPeers) {
val start = System.nanoTime
@@ -808,16 +767,6 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
- // TODO: This code will be removed when CacheTracker is gone.
- private def notifyCacheTracker(key: String) {
- if (cacheTracker != null) {
- val rddInfo = key.split("_")
- val rddId: Int = rddInfo(1).toInt
- val partition: Int = rddInfo(2).toInt
- cacheTracker.notifyFromBlockManager(spark.AddedToCache(rddId, partition, host))
- }
- }
-
/**
* Read a block consisting of a single object.
*/
@@ -838,7 +787,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
*/
def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
- val info = blockInfo.get(blockId)
+ val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
val level = info.level
@@ -851,9 +800,12 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
diskStore.putBytes(blockId, bytes, level)
}
}
- memoryStore.remove(blockId)
+ val blockWasRemoved = memoryStore.remove(blockId)
+ if (!blockWasRemoved) {
+ logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
+ }
if (info.tellMaster) {
- reportBlockStatus(blockId)
+ reportBlockStatus(blockId, info)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@@ -865,6 +817,53 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
}
}
+ /**
+ * Remove a block from both memory and disk.
+ */
+ def removeBlock(blockId: String) {
+ logInfo("Removing block " + blockId)
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) info.synchronized {
+ // Removals are idempotent in disk store and memory store. At worst, we get a warning.
+ val removedFromMemory = memoryStore.remove(blockId)
+ val removedFromDisk = diskStore.remove(blockId)
+ if (!removedFromMemory && !removedFromDisk) {
+ logWarning("Block " + blockId + " could not be removed as it was not found in either " +
+ "the disk or memory store")
+ }
+ blockInfo.remove(blockId)
+ if (info.tellMaster) {
+ reportBlockStatus(blockId, info)
+ }
+ } else {
+ // The block has already been removed; do nothing.
+ logWarning("Asked to remove block " + blockId + ", which does not exist")
+ }
+ }
+
+ def dropOldBlocks(cleanupTime: Long) {
+ logInfo("Dropping blocks older than " + cleanupTime)
+ val iterator = blockInfo.internalMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ if (time < cleanupTime) {
+ info.synchronized {
+ val level = info.level
+ if (level.useMemory) {
+ memoryStore.remove(id)
+ }
+ if (level.useDisk) {
+ diskStore.remove(id)
+ }
+ iterator.remove()
+ logInfo("Dropped block " + id)
+ }
+ reportBlockStatus(id, info)
+ }
+ }
+ }
+
def shouldCompress(blockId: String): Boolean = {
if (blockId.startsWith("shuffle_")) {
compressShuffle
@@ -914,6 +913,7 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
heartBeatTask.cancel()
}
connectionManager.stop()
+ master.actorSystem.stop(slaveActor)
blockInfo.clear()
memoryStore.clear()
diskStore.clear()
@@ -923,6 +923,9 @@ class BlockManager(actorSystem: ActorSystem, val master: BlockManagerMaster,
private[spark]
object BlockManager extends Logging {
+
+ val ID_GENERATOR = new IdGenerator
+
def getMaxMemoryFromSystemProperties: Long = {
val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
(Runtime.getRuntime.maxMemory * memoryFraction).toLong
diff --git a/core/src/main/scala/spark/storage/BlockManagerId.scala b/core/src/main/scala/spark/storage/BlockManagerId.scala
new file mode 100644
index 0000000000..abb8b45a1f
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerId.scala
@@ -0,0 +1,70 @@
+package spark.storage
+
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
+
+/**
+ * This class represents a unique identifier for a BlockManager.
+ * Both constructors of this class are private to ensure that
+ * BlockManagerId objects can be created only using the factory methods in
+ * [[spark.storage.BlockManagerId$]]. This allows de-duplication of id objects.
+ * Also, constructor parameters are private to ensure that parameters cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+ private var ip_ : String,
+ private var port_ : Int
+ ) extends Externalizable {
+
+ private def this() = this(null, 0) // For deserialization only
+
+ def ip = ip_
+
+ def port = port_
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeUTF(ip_)
+ out.writeInt(port_)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ ip_ = in.readUTF()
+ port_ = in.readInt()
+ }
+
+ @throws(classOf[IOException])
+ private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
+
+ override def toString = "BlockManagerId(" + ip + ", " + port + ")"
+
+ override def hashCode = ip.hashCode * 41 + port
+
+ override def equals(that: Any) = that match {
+ case id: BlockManagerId => port == id.port && ip == id.ip
+ case _ => false
+ }
+}
+
+
+private[spark] object BlockManagerId {
+
+ def apply(ip: String, port: Int) =
+ getCachedBlockManagerId(new BlockManagerId(ip, port))
+
+ def apply(in: ObjectInput) = {
+ val obj = new BlockManagerId()
+ obj.readExternal(in)
+ getCachedBlockManagerId(obj)
+ }
+
+ val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+
+ def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
+ if (blockManagerIdCache.containsKey(id)) {
+ blockManagerIdCache.get(id)
+ } else {
+ blockManagerIdCache.put(id, id)
+ id
+ }
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index 0a4e68f437..a3d8671834 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -1,676 +1,167 @@
package spark.storage
-import java.io._
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.ArrayBuffer
import scala.util.Random
-import akka.actor._
-import akka.dispatch._
+import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote._
import akka.util.{Duration, Timeout}
import akka.util.duration._
import spark.{Logging, SparkException, Utils}
-private[spark]
-sealed trait ToBlockManagerMaster
-
-private[spark]
-case class RegisterBlockManager(
- blockManagerId: BlockManagerId,
- maxMemSize: Long)
- extends ToBlockManagerMaster
-
-private[spark]
-case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
-
-private[spark]
-class BlockUpdate(
- var blockManagerId: BlockManagerId,
- var blockId: String,
- var storageLevel: StorageLevel,
- var memSize: Long,
- var diskSize: Long)
- extends ToBlockManagerMaster
- with Externalizable {
-
- def this() = this(null, null, null, 0, 0) // For deserialization only
-
- override def writeExternal(out: ObjectOutput) {
- blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
- storageLevel.writeExternal(out)
- out.writeInt(memSize.toInt)
- out.writeInt(diskSize.toInt)
- }
-
- override def readExternal(in: ObjectInput) {
- blockManagerId = new BlockManagerId()
- blockManagerId.readExternal(in)
- blockId = in.readUTF()
- storageLevel = new StorageLevel()
- storageLevel.readExternal(in)
- memSize = in.readInt()
- diskSize = in.readInt()
- }
-}
-
-private[spark]
-object BlockUpdate {
- def apply(blockManagerId: BlockManagerId,
- blockId: String,
- storageLevel: StorageLevel,
- memSize: Long,
- diskSize: Long): BlockUpdate = {
- new BlockUpdate(blockManagerId, blockId, storageLevel, memSize, diskSize)
- }
-
- // For pattern-matching
- def unapply(h: BlockUpdate): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
- Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
- }
-}
-
-private[spark]
-case class GetLocations(blockId: String) extends ToBlockManagerMaster
-
-private[spark]
-case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
-
-private[spark]
-case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
-
-private[spark]
-case class RemoveHost(host: String) extends ToBlockManagerMaster
-
-private[spark]
-case object StopBlockManagerMaster extends ToBlockManagerMaster
-
-private[spark]
-case object GetMemoryStatus extends ToBlockManagerMaster
-
-private[spark]
-case object ExpireDeadHosts extends ToBlockManagerMaster
-
-
-private[spark] class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
-
- class BlockManagerInfo(
- val blockManagerId: BlockManagerId,
- timeMs: Long,
- val maxMem: Long) {
- private var _lastSeenMs = timeMs
- private var _remainingMem = maxMem
- private val _blocks = new JHashMap[String, StorageLevel]
-
- logInfo("Registering block manager %s:%d with %s RAM".format(
- blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
-
- def updateLastSeenMs() {
- _lastSeenMs = System.currentTimeMillis()
- }
-
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
- : Unit = synchronized {
-
- updateLastSeenMs()
-
- if (_blocks.containsKey(blockId)) {
- // The block exists on the slave already.
- val originalLevel: StorageLevel = _blocks.get(blockId)
-
- if (originalLevel.useMemory) {
- _remainingMem += memSize
- }
- }
-
- if (storageLevel.isValid) {
- // isValid means it is either stored in-memory or on-disk.
- _blocks.put(blockId, storageLevel)
- if (storageLevel.useMemory) {
- _remainingMem -= memSize
- logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
- Utils.memoryBytesToString(_remainingMem)))
- }
- if (storageLevel.useDisk) {
- logInfo("Added %s on disk on %s:%d (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
- }
- } else if (_blocks.containsKey(blockId)) {
- // If isValid is not true, drop the block.
- val originalLevel: StorageLevel = _blocks.get(blockId)
- _blocks.remove(blockId)
- if (originalLevel.useMemory) {
- _remainingMem += memSize
- logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
- Utils.memoryBytesToString(_remainingMem)))
- }
- if (originalLevel.useDisk) {
- logInfo("Removed %s on %s:%d on disk (size: %s)".format(
- blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
- }
- }
- }
-
- def remainingMem: Long = _remainingMem
-
- def lastSeenMs: Long = _lastSeenMs
-
- def blocks: JHashMap[String, StorageLevel] = _blocks
-
- override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
-
- def clear() {
- _blocks.clear()
- }
- }
-
- private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo]
- private val blockManagerIdByHost = new HashMap[String, BlockManagerId]
- private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
-
- initLogging()
-
- val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
- "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
-
- val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
- "5000").toLong
-
- var timeoutCheckingTask: Cancellable = null
+private[spark] class BlockManagerMaster(
+ val actorSystem: ActorSystem,
+ isMaster: Boolean,
+ isLocal: Boolean,
+ masterIp: String,
+ masterPort: Int)
+ extends Logging {
- override def preStart() {
- if (!BlockManager.getDisableHeartBeatsForTesting) {
- timeoutCheckingTask = context.system.scheduler.schedule(
- 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
- }
- super.preStart()
- }
+ val AKKA_RETRY_ATTEMPS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
- def removeBlockManager(blockManagerId: BlockManagerId) {
- val info = blockManagerInfo(blockManagerId)
- blockManagerIdByHost.remove(blockManagerId.ip)
- blockManagerInfo.remove(blockManagerId)
- var iterator = info.blocks.keySet.iterator
- while (iterator.hasNext) {
- val blockId = iterator.next
- val locations = blockInfo.get(blockId)._2
- locations -= blockManagerId
- if (locations.size == 0) {
- blockInfo.remove(locations)
- }
- }
- }
-
- def expireDeadHosts() {
- logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
- val now = System.currentTimeMillis()
- val minSeenTime = now - slaveTimeout
- val toRemove = new HashSet[BlockManagerId]
- for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime) {
- logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats")
- toRemove += info.blockManagerId
- }
- }
- // TODO: Remove corresponding block infos
- toRemove.foreach(removeBlockManager)
- }
-
- def removeHost(host: String) {
- logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
- logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
- blockManagerIdByHost.get(host).foreach(removeBlockManager)
- logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
- sender ! true
- }
+ val MASTER_AKKA_ACTOR_NAME = "BlockMasterManager"
+ val SLAVE_AKKA_ACTOR_NAME = "BlockSlaveManager"
+ val DEFAULT_MANAGER_IP: String = Utils.localHostName()
- def heartBeat(blockManagerId: BlockManagerId) {
- if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- sender ! true
- } else {
- sender ! false
- }
+ val timeout = 10.seconds
+ var masterActor: ActorRef = {
+ if (isMaster) {
+ val masterActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
+ name = MASTER_AKKA_ACTOR_NAME)
+ logInfo("Registered BlockManagerMaster Actor")
+ masterActor
} else {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- sender ! true
+ val url = "akka://spark@%s:%s/user/%s".format(masterIp, masterPort, MASTER_AKKA_ACTOR_NAME)
+ logInfo("Connecting to BlockManagerMaster: " + url)
+ actorSystem.actorFor(url)
}
}
- def receive = {
- case RegisterBlockManager(blockManagerId, maxMemSize) =>
- register(blockManagerId, maxMemSize)
-
- case BlockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
- blockUpdate(blockManagerId, blockId, storageLevel, deserializedSize, size)
-
- case GetLocations(blockId) =>
- getLocations(blockId)
-
- case GetLocationsMultipleBlockIds(blockIds) =>
- getLocationsMultipleBlockIds(blockIds)
-
- case GetPeers(blockManagerId, size) =>
- getPeersDeterministic(blockManagerId, size)
- /*getPeers(blockManagerId, size)*/
-
- case GetMemoryStatus =>
- getMemoryStatus
-
- case RemoveHost(host) =>
- removeHost(host)
- sender ! true
-
- case StopBlockManagerMaster =>
- logInfo("Stopping BlockManagerMaster")
- sender ! true
- if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel
- }
- context.stop(self)
-
- case ExpireDeadHosts =>
- expireDeadHosts()
-
- case HeartBeat(blockManagerId) =>
- heartBeat(blockManagerId)
-
- case other =>
- logInfo("Got unknown message: " + other)
+ /** Remove a dead host from the master actor. This is only called on the master side. */
+ def notifyADeadHost(host: String) {
+ tell(RemoveHost(host))
+ logInfo("Removed " + host + " successfully in notifyADeadHost")
}
- // Return a map from the block manager id to max memory and remaining memory.
- private def getMemoryStatus() {
- val res = blockManagerInfo.map { case(blockManagerId, info) =>
- (blockManagerId, (info.maxMem, info.remainingMem))
- }.toMap
- sender ! res
+ /**
+ * Send the master actor a heart beat from the slave. Returns true if everything works out,
+ * false if the master does not know about the given block manager, which means the block
+ * manager should re-register.
+ */
+ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
+ askMasterWithRetry[Boolean](HeartBeat(blockManagerId))
}
- private def register(blockManagerId: BlockManagerId, maxMemSize: Long) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " "
- logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- if (blockManagerIdByHost.contains(blockManagerId.ip) &&
- blockManagerIdByHost(blockManagerId.ip) != blockManagerId) {
- val oldId = blockManagerIdByHost(blockManagerId.ip)
- logInfo("Got second registration for host " + blockManagerId +
- "; removing old slave " + oldId)
- removeBlockManager(oldId)
- }
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- logInfo("Got Register Msg from master node, don't register it")
- } else {
- blockManagerInfo += (blockManagerId -> new BlockManagerInfo(
- blockManagerId, System.currentTimeMillis(), maxMemSize))
- }
- blockManagerIdByHost += (blockManagerId.ip -> blockManagerId)
- logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs))
- sender ! true
+ /** Register the BlockManager's id with the master. */
+ def registerBlockManager(
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ logInfo("Trying to register BlockManager")
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
+ logInfo("Registered BlockManager")
}
- private def blockUpdate(
+ def updateBlockInfo(
blockManagerId: BlockManagerId,
blockId: String,
storageLevel: StorageLevel,
memSize: Long,
- diskSize: Long) {
-
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockManagerId + " " + blockId + " "
-
- if (!blockManagerInfo.contains(blockManagerId)) {
- if (blockManagerId.ip == Utils.localHostName() && !isLocal) {
- // We intentionally do not register the master (except in local mode),
- // so we should not indicate failure.
- sender ! true
- } else {
- sender ! false
- }
- return
- }
-
- if (blockId == null) {
- blockManagerInfo(blockManagerId).updateLastSeenMs()
- logDebug("Got in block update 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs))
- sender ! true
- return
- }
-
- blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
-
- var locations: HashSet[BlockManagerId] = null
- if (blockInfo.containsKey(blockId)) {
- locations = blockInfo.get(blockId)._2
- } else {
- locations = new HashSet[BlockManagerId]
- blockInfo.put(blockId, (storageLevel.replication, locations))
- }
-
- if (storageLevel.isValid) {
- locations += blockManagerId
- } else {
- locations.remove(blockManagerId)
- }
-
- if (locations.size == 0) {
- blockInfo.remove(blockId)
- }
- sender ! true
+ diskSize: Long): Boolean = {
+ val res = askMasterWithRetry[Boolean](
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
+ logInfo("Updated info of block " + blockId)
+ res
}
- private def getLocations(blockId: String) {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " " + blockId + " "
- logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs))
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at "
- + Utils.getUsedTimeMs(startTimeMs))
- sender ! res.toSeq
- } else {
- logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs))
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- sender ! res
- }
+ /** Get locations of the blockId from the master */
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ askMasterWithRetry[Seq[BlockManagerId]](GetLocations(blockId))
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
- def getLocations(blockId: String): Seq[BlockManagerId] = {
- val tmp = blockId
- logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp)
- if (blockInfo.containsKey(blockId)) {
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(blockInfo.get(blockId)._2)
- logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq)
- return res.toSeq
- } else {
- logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp)
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- return res.toSeq
- }
- }
-
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq)
- var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]]
- for (blockId <- blockIds) {
- res.append(getLocations(blockId))
- }
- logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq)
- sender ! res.toSeq
+ /** Get locations of multiple blockIds from the master */
+ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ askMasterWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
- private def getPeers(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
- res.appendAll(peers)
- res -= blockManagerId
- val rand = new Random(System.currentTimeMillis())
- while (res.length > size) {
- res.remove(rand.nextInt(res.length))
+ /** Get ids of other nodes in the cluster from the master */
+ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
+ val result = askMasterWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
+ if (result.length != numPeers) {
+ throw new SparkException(
+ "Error getting peers, only got " + result.size + " instead of " + numPeers)
}
- sender ! res.toSeq
+ result
}
- private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
- var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
- var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
-
- val peersWithIndices = peers.zipWithIndex
- val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1)
- if (selfIndex == -1) {
- throw new Exception("Self index for " + blockManagerId + " not found")
- }
-
- var index = selfIndex
- while (res.size < size) {
- index += 1
- if (index == selfIndex) {
- throw new Exception("More peer expected than available")
- }
- res += peers(index % peers.size)
- }
- sender ! res.toSeq
+ /**
+ * Remove a block from the slaves that have it. This can only be used to remove
+ * blocks that the master knows about.
+ */
+ def removeBlock(blockId: String) {
+ askMasterWithRetry(RemoveBlock(blockId))
}
-}
-
-private[spark] class BlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean)
- extends Logging {
-
- val AKKA_ACTOR_NAME: String = "BlockMasterManager"
- val REQUEST_RETRY_INTERVAL_MS = 100
- val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost")
- val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt
- val DEFAULT_MANAGER_IP: String = Utils.localHostName()
- val timeout = 10.seconds
- var masterActor: ActorRef = null
-
- if (isMaster) {
- masterActor = actorSystem.actorOf(
- Props(new BlockManagerMasterActor(isLocal)), name = AKKA_ACTOR_NAME)
- logInfo("Registered BlockManagerMaster Actor")
- } else {
- val url = "akka://spark@%s:%s/user/%s".format(
- DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME)
- logInfo("Connecting to BlockManagerMaster: " + url)
- masterActor = actorSystem.actorFor(url)
+ /**
+ * Return the memory status for each block manager, in the form of a map from
+ * the block manager's id to two long values. The first value is the maximum
+ * amount of memory allocated for the block manager, while the second is the
+ * amount of remaining memory.
+ */
+ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ askMasterWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
}
+ /** Stop the master actor, called only on the Spark master node */
def stop() {
if (masterActor != null) {
- communicate(StopBlockManagerMaster)
+ tell(StopBlockManagerMaster)
masterActor = null
logInfo("BlockManagerMaster stopped")
}
}
- // Send a message to the master actor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- def askMaster(message: Any): Any = {
- try {
- val future = masterActor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with BlockManagerMaster", e)
- }
- }
-
- // Send a one-way message to the master actor, to which we expect it to reply with true.
- def communicate(message: Any) {
- if (askMaster(message) != true) {
- throw new SparkException("Error reply received from BlockManagerMaster")
- }
- }
-
- def notifyADeadHost(host: String) {
- communicate(RemoveHost(host))
- logInfo("Removed " + host + " successfully in notifyADeadHost")
- }
-
- def mustRegisterBlockManager(msg: RegisterBlockManager) {
- logInfo("Trying to register BlockManager")
- while (! syncRegisterBlockManager(msg)) {
- logWarning("Failed to register " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- logInfo("Done registering BlockManager")
- }
-
- def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = {
- //val masterActor = RemoteActor.select(node, name)
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- communicate(msg)
- logInfo("BlockManager registered successfully @ syncRegisterBlockManager")
- logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return true
- } catch {
- case e: Exception =>
- logError("Failed in syncRegisterBlockManager", e)
- return false
- }
- }
-
- def mustHeartBeat(msg: HeartBeat): Boolean = {
- var res = syncHeartBeat(msg)
- while (!res.isDefined) {
- logWarning("Failed to send heart beat " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
+ /** Send a one-way message to the master actor, to which we expect it to reply with true. */
+ private def tell(message: Any) {
+ if (!askMasterWithRetry[Boolean](message)) {
+ throw new SparkException("BlockManagerMasterActor returned false, expected true.")
}
- return res.get
}
- def syncHeartBeat(msg: HeartBeat): Option[Boolean] = {
- try {
- val answer = askMaster(msg).asInstanceOf[Boolean]
- return Some(answer)
- } catch {
- case e: Exception =>
- logError("Failed in syncHeartBeat", e)
- return None
+ /**
+ * Send a message to the master actor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ private def askMasterWithRetry[T](message: Any): T = {
+ // TODO: Consider removing multiple attempts
+ if (masterActor == null) {
+ throw new SparkException("Error sending message to BlockManager as masterActor is null " +
+ "[message = " + message + "]")
}
- }
-
- def mustBlockUpdate(msg: BlockUpdate): Boolean = {
- var res = syncBlockUpdate(msg)
- while (!res.isDefined) {
- logWarning("Failed to send block update " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- }
- return res.get
- }
-
- def syncBlockUpdate(msg: BlockUpdate): Option[Boolean] = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncBlockUpdate " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Boolean]
- logDebug("Block update sent successfully")
- logDebug("Got in synbBlockUpdate " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs))
- return Some(answer)
- } catch {
- case e: Exception =>
- logError("Failed in syncBlockUpdate", e)
- return None
- }
- }
-
- def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
- var res = syncGetLocations(msg)
- while (res == null) {
- logInfo("Failed to get locations " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocations(msg)
- }
- return res
- }
-
- def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis()
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]]
- if (answer != null) {
- logDebug("GetLocations successful")
- logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetLocations")
- return null
+ var attempts = 0
+ var lastException: Exception = null
+ while (attempts < AKKA_RETRY_ATTEMPS) {
+ attempts += 1
+ try {
+ val future = masterActor.ask(message)(timeout)
+ val result = Await.result(future, timeout)
+ if (result == null) {
+ throw new Exception("BlockManagerMaster returned null")
+ }
+ return result.asInstanceOf[T]
+ } catch {
+ case ie: InterruptedException => throw ie
+ case e: Exception =>
+ lastException = e
+ logWarning("Error sending message to BlockManagerMaster in " + attempts + " attempts", e)
}
- } catch {
- case e: Exception =>
- logError("GetLocations failed", e)
- return null
+ Thread.sleep(AKKA_RETRY_INTERVAL_MS)
}
- }
- def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg)
- while (res == null) {
- logWarning("Failed to GetLocationsMultipleBlockIds " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetLocationsMultipleBlockIds(msg)
- }
- return res
+ throw new SparkException(
+ "Error sending message to BlockManagerMaster [message = " + message + "]", lastException)
}
- def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds):
- Seq[Seq[BlockManagerId]] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]]
- if (answer != null) {
- logDebug("GetLocationsMultipleBlockIds successful")
- logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp +
- Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetLocationsMultipleBlockIds")
- return null
- }
- } catch {
- case e: Exception =>
- logError("GetLocationsMultipleBlockIds failed", e)
- return null
- }
- }
-
- def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
- var res = syncGetPeers(msg)
- while ((res == null) || (res.length != msg.size)) {
- logInfo("Failed to get peers " + msg)
- Thread.sleep(REQUEST_RETRY_INTERVAL_MS)
- res = syncGetPeers(msg)
- }
-
- return res
- }
-
- def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = {
- val startTimeMs = System.currentTimeMillis
- val tmp = " msg " + msg + " "
- logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs))
-
- try {
- val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]]
- if (answer != null) {
- logDebug("GetPeers successful")
- logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs))
- return answer
- } else {
- logError("Master replied null in response to GetPeers")
- return null
- }
- } catch {
- case e: Exception =>
- logError("GetPeers failed", e)
- return null
- }
- }
-
- def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
- askMaster(GetMemoryStatus).asInstanceOf[Map[BlockManagerId, (Long, Long)]]
- }
}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
new file mode 100644
index 0000000000..f4d026da33
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMasterActor.scala
@@ -0,0 +1,401 @@
+package spark.storage
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import scala.util.Random
+
+import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.util.{Duration, Timeout}
+import akka.util.duration._
+
+import spark.{Logging, Utils}
+
+/**
+ * BlockManagerMasterActor is an actor on the master node to track statuses of
+ * all slaves' block managers.
+ */
+private[spark]
+class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
+
+ // Mapping from block manager id to the block manager's information.
+ private val blockManagerInfo =
+ new HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+
+ // Mapping from host name to block manager id. We allow multiple block managers
+ // on the same host name (ip).
+ private val blockManagerIdByHost = new HashMap[String, ArrayBuffer[BlockManagerId]]
+
+ // Mapping from block id to the set of block managers that have the block.
+ private val blockLocations = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]]
+
+ initLogging()
+
+ val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
+ "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
+
+ val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
+ "5000").toLong
+
+ var timeoutCheckingTask: Cancellable = null
+
+ // Actor start-up hook: unless heart beats are disabled for testing, schedule an
+ // ExpireDeadHosts message to this actor, first immediately (0 seconds) and then
+ // every checkTimeoutInterval milliseconds. The Cancellable is kept in
+ // timeoutCheckingTask so the StopBlockManagerMaster handler can cancel it.
+ override def preStart() {
+ if (!BlockManager.getDisableHeartBeatsForTesting) {
+ timeoutCheckingTask = context.system.scheduler.schedule(
+ 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
+ }
+ super.preStart()
+ }
+
+ /**
+  * Message dispatch for the master actor. Each request is forwarded to the handler
+  * method of the same purpose; the handlers send the reply to `sender` themselves.
+  * ExpireDeadHosts is the periodic self-message scheduled in preStart and gets no reply.
+  * Unknown messages are only logged.
+  */
+ def receive = {
+   case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
+     register(blockManagerId, maxMemSize, slaveActor)
+
+   case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+     updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+   case GetLocations(blockId) =>
+     getLocations(blockId)
+
+   case GetLocationsMultipleBlockIds(blockIds) =>
+     getLocationsMultipleBlockIds(blockIds)
+
+   case GetPeers(blockManagerId, size) =>
+     getPeersDeterministic(blockManagerId, size)
+     /*getPeers(blockManagerId, size)*/
+
+   case GetMemoryStatus =>
+     getMemoryStatus
+
+   case RemoveBlock(blockId) =>
+     removeBlock(blockId)
+
+   case RemoveHost(host) =>
+     // removeHost() already answers `true` to the sender. Replying here as well would
+     // send a second answer for a single request.
+     removeHost(host)
+
+   case StopBlockManagerMaster =>
+     logInfo("Stopping BlockManagerMaster")
+     // Acknowledge before stopping so the caller is not left waiting on a stopped actor.
+     sender ! true
+     if (timeoutCheckingTask != null) {
+       timeoutCheckingTask.cancel
+     }
+     context.stop(self)
+
+   case ExpireDeadHosts =>
+     expireDeadHosts()
+
+   case HeartBeat(blockManagerId) =>
+     heartBeat(blockManagerId)
+
+   case other =>
+     logInfo("Got unknown message: " + other)
+ }
+
+ /**
+  * Forget a block manager: remove it from the per-host index, from blockManagerInfo,
+  * and from the location set of every block it was holding.
+  *
+  * Looks the manager up with blockManagerInfo(...), so an id that is not registered
+  * fails with the map's missing-key exception.
+  */
+ def removeBlockManager(blockManagerId: BlockManagerId) {
+   val info = blockManagerInfo(blockManagerId)
+
+   // Remove the block manager from blockManagerIdByHost. If the list of block
+   // managers belonging to the IP is empty, remove the entry from the hash map.
+   blockManagerIdByHost.get(blockManagerId.ip).foreach { managers: ArrayBuffer[BlockManagerId] =>
+     managers -= blockManagerId
+     if (managers.size == 0) blockManagerIdByHost.remove(blockManagerId.ip)
+   }
+
+   // Remove it from blockManagerInfo and remove all the blocks.
+   blockManagerInfo.remove(blockManagerId)
+   val iterator = info.blocks.keySet.iterator
+   while (iterator.hasNext) {
+     val blockId = iterator.next
+     val locations = blockLocations.get(blockId)._2
+     locations -= blockManagerId
+     if (locations.size == 0) {
+       // blockLocations is keyed by block id; removing by the location set never
+       // matches a key and would leave the empty entry behind.
+       blockLocations.remove(blockId)
+     }
+   }
+ }
+
+ /**
+  * Drop every block manager whose last heart beat is older than slaveTimeout
+  * milliseconds. The ids are collected first and removed afterwards, so
+  * blockManagerInfo is not modified while it is being traversed.
+  */
+ def expireDeadHosts() {
+   logDebug("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+   val deadline = System.currentTimeMillis() - slaveTimeout
+   val expired = new HashSet[BlockManagerId]
+   blockManagerInfo.values.filter(_.lastSeenMs < deadline).foreach { stale =>
+     logWarning("Removing BlockManager " + stale.blockManagerId + " with no recent heart beats")
+     expired += stale.blockManagerId
+   }
+   expired.foreach(removeBlockManager)
+ }
+
+ /**
+  * Remove every block manager registered under the given host name, logging the
+  * known block managers before and after, then reply `true` to the sender.
+  */
+ def removeHost(host: String) {
+   logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.")
+   logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq)
+   // removeBlockManager() deletes the id from this very buffer, so walk a snapshot:
+   // mutating the buffer while traversing it can skip managers or visit stale slots
+   // when several block managers share the host.
+   blockManagerIdByHost.get(host).foreach(_.toList.foreach(removeBlockManager))
+   logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq)
+   sender ! true
+ }
+
+ /**
+  * Handle a heart beat. A registered block manager gets its last-seen time refreshed
+  * and a `true` reply. An unregistered one gets `true` only when it lives on this
+  * host and the actor is not in local mode, and `false` otherwise.
+  */
+ def heartBeat(blockManagerId: BlockManagerId) {
+   blockManagerInfo.get(blockManagerId) match {
+     case Some(info) =>
+       info.updateLastSeenMs()
+       sender ! true
+     case None =>
+       sender ! (blockManagerId.ip == Utils.localHostName() && !isLocal)
+   }
+ }
+
+ // Ask every slave that holds the block to drop it, then reply `true` to the sender.
+ // Only blocks present in blockLocations (i.e. known to the master) are affected.
+ private def removeBlock(blockId: String) {
+   val entry = blockLocations.get(blockId)
+   if (entry != null) {
+     // Fire-and-forget: no confirmation is awaited, so a lost message is not noticed.
+     // If message loss becomes frequent, retry logic should be added here.
+     for (holder <- entry._2; info <- blockManagerInfo.get(holder)) {
+       info.slaveActor ! RemoveBlock(blockId)
+     }
+   }
+   sender ! true
+ }
+
+ // Reply with an immutable map: block manager id -> (max memory, remaining memory).
+ private def getMemoryStatus() {
+   val status =
+     (for ((id, info) <- blockManagerInfo) yield (id, (info.maxMem, info.remainingMem))).toMap
+   sender ! status
+ }
+
+ // Record a newly started block manager and reply `true`. A registration coming from
+ // this host while not in local mode is acknowledged but not recorded.
+ private def register(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+   val fromMasterNode = blockManagerId.ip == Utils.localHostName() && !isLocal
+   if (fromMasterNode) {
+     logInfo("Got Register Msg from master node, don't register it")
+   } else {
+     val host = blockManagerId.ip
+     if (blockManagerIdByHost.contains(host)) {
+       // A block manager of the same host name already exists.
+       logInfo("Got another registration for host " + blockManagerId)
+     }
+     blockManagerIdByHost.getOrElseUpdate(host, new ArrayBuffer[BlockManagerId]) += blockManagerId
+
+     val info = new BlockManagerMasterActor.BlockManagerInfo(
+       blockManagerId, System.currentTimeMillis(), maxMemSize, slaveActor)
+     blockManagerInfo += (blockManagerId -> info)
+   }
+   sender ! true
+ }
+
+ // Apply a block status report from a block manager and reply with a Boolean:
+ // `false` tells an unknown (non-master) block manager that the master does not know it.
+ private def updateBlockInfo(
+     blockManagerId: BlockManagerId,
+     blockId: String,
+     storageLevel: StorageLevel,
+     memSize: Long,
+     diskSize: Long) {
+
+   blockManagerInfo.get(blockManagerId) match {
+     case None =>
+       // We intentionally do not register the master (except in local mode),
+       // so we should not indicate failure for it.
+       sender ! (blockManagerId.ip == Utils.localHostName() && !isLocal)
+
+     case Some(info) if blockId == null =>
+       // No block attached: only refresh the last-seen time.
+       info.updateLastSeenMs()
+       sender ! true
+
+     case Some(info) =>
+       info.updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+
+       val existing = blockLocations.get(blockId)
+       val locations: HashSet[BlockManagerId] =
+         if (existing != null) {
+           existing._2
+         } else {
+           val fresh = new HashSet[BlockManagerId]
+           blockLocations.put(blockId, (storageLevel.replication, fresh))
+           fresh
+         }
+
+       if (storageLevel.isValid) {
+         locations.add(blockManagerId)
+       } else {
+         locations.remove(blockManagerId)
+       }
+
+       // Remove the block from master tracking if it has been removed on all slaves.
+       if (locations.isEmpty) {
+         blockLocations.remove(blockId)
+       }
+       sender ! true
+   }
+ }
+
+ // Reply with the block managers currently holding the block (empty if unknown).
+ private def getLocations(blockId: String) {
+   val holders = new ArrayBuffer[BlockManagerId]
+   val entry = blockLocations.get(blockId)
+   if (entry != null) {
+     holders.appendAll(entry._2)
+   }
+   sender ! holders.toSeq
+ }
+
+ // Reply with one location list per requested block id, in request order.
+ // Unknown blocks yield an empty list.
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]) {
+   def locationsOf(blockId: String): Seq[BlockManagerId] = {
+     val holders = new ArrayBuffer[BlockManagerId]
+     val entry = blockLocations.get(blockId)
+     if (entry != null) {
+       holders.appendAll(entry._2)
+     }
+     holders.toSeq
+   }
+
+   val all = new ArrayBuffer[Seq[BlockManagerId]]
+   blockIds.foreach { id => all.append(locationsOf(id)) }
+   sender ! all.toSeq
+ }
+
+ // Reply with at most `size` other block managers: start from all known ids except
+ // the requester and randomly discard entries until no more than `size` remain.
+ private def getPeers(blockManagerId: BlockManagerId, size: Int) {
+   val candidates = new ArrayBuffer[BlockManagerId]
+   candidates ++= blockManagerInfo.keySet
+   candidates -= blockManagerId
+   val rand = new Random(System.currentTimeMillis())
+   while (candidates.length > size) {
+     candidates.remove(rand.nextInt(candidates.length))
+   }
+   sender ! candidates.toSeq
+ }
+
+ // Reply with `size` peers chosen by walking the array of known block manager ids,
+ // starting at the slot after the requester and wrapping around with a modulo.
+ // Throws if the requester itself is not in blockManagerInfo.
+ private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) {
+ var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+ var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId]
+
+ val selfIndex = peers.indexOf(blockManagerId)
+ if (selfIndex == -1) {
+ throw new Exception("Self index for " + blockManagerId + " not found")
+ }
+
+ // Note that this logic will select the same node multiple times if there aren't enough peers
+ // (the modulo wraps the walk around, so it can also land on the requester's own slot).
+ var index = selfIndex
+ while (res.size < size) {
+ index += 1
+ // NOTE(review): index only grows from selfIndex + 1 and is compared without the
+ // modulo, so this guard does not look reachable - confirm the intended check.
+ if (index == selfIndex) {
+ throw new Exception("More peer expected than available")
+ }
+ res += peers(index % peers.size)
+ }
+ sender ! res.toSeq
+ }
+}
+
+
+private[spark]
+object BlockManagerMasterActor {
+
+  /** Storage level and sizes last reported for one block on one block manager. */
+  case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+
+  /**
+   * Master-side bookkeeping for a single block manager: when it was last seen, how
+   * much of its memory is still free, and the status of every block it reported.
+   *
+   * @param blockManagerId id of the block manager being tracked
+   * @param timeMs initial value of the last-seen timestamp
+   * @param maxMem maximum memory of the block manager; remaining memory starts here
+   * @param slaveActor actor reference stored for this block manager
+   */
+  class BlockManagerInfo(
+      val blockManagerId: BlockManagerId,
+      timeMs: Long,
+      val maxMem: Long,
+      val slaveActor: ActorRef)
+    extends Logging {
+
+    private var _lastSeenMs: Long = timeMs
+    private var _remainingMem: Long = maxMem
+
+    // Mapping from block id to its status.
+    private val _blocks = new JHashMap[String, BlockStatus]
+
+    logInfo("Registering block manager %s:%d with %s RAM".format(
+      blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
+
+    /** Record that the block manager was heard from just now. */
+    def updateLastSeenMs() {
+      _lastSeenMs = System.currentTimeMillis()
+    }
+
+    /**
+     * Apply a status report for one block and keep the remaining-memory figure in step.
+     *
+     * Memory held by the previously recorded version of the block is credited back
+     * using the size stored at that time, exactly once. A valid storage level then
+     * (re)records the block and charges its new in-memory size; an invalid level
+     * drops a known block.
+     */
+    def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+      : Unit = synchronized {
+
+      updateLastSeenMs()
+
+      // Status recorded by an earlier report, if the block exists on the slave already.
+      val previous: Option[BlockStatus] = Option(_blocks.get(blockId))
+
+      // Credit back what the old version occupied. The incoming memSize describes the
+      // new state (e.g. 0 after a drop to disk), so the stored size must be used here.
+      for (old <- previous if old.storageLevel.useMemory) {
+        _remainingMem += old.memSize
+      }
+
+      if (storageLevel.isValid) {
+        // isValid means it is either stored in-memory or on-disk.
+        _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
+        if (storageLevel.useMemory) {
+          _remainingMem -= memSize
+          logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
+            blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
+            Utils.memoryBytesToString(_remainingMem)))
+        }
+        if (storageLevel.useDisk) {
+          logInfo("Added %s on disk on %s:%d (size: %s)".format(
+            blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
+        }
+      } else {
+        // If isValid is not true, drop the block. Its memory was already credited
+        // back above, so only the bookkeeping entry and the log lines remain.
+        previous.foreach { old =>
+          _blocks.remove(blockId)
+          if (old.storageLevel.useMemory) {
+            logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
+              blockId, blockManagerId.ip, blockManagerId.port,
+              Utils.memoryBytesToString(old.memSize),
+              Utils.memoryBytesToString(_remainingMem)))
+          }
+          if (old.storageLevel.useDisk) {
+            logInfo("Removed %s on %s:%d on disk (size: %s)".format(
+              blockId, blockManagerId.ip, blockManagerId.port,
+              Utils.memoryBytesToString(old.diskSize)))
+          }
+        }
+      }
+    }
+
+    /** Memory still available on the block manager, as tracked by the master. */
+    def remainingMem: Long = _remainingMem
+
+    /** Timestamp (ms) of the last update or heart beat seen from the block manager. */
+    def lastSeenMs: Long = _lastSeenMs
+
+    /** Live view of the block id -> status map (not a copy). */
+    def blocks: JHashMap[String, BlockStatus] = _blocks
+
+    override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
+
+    /** Forget all recorded blocks; the remaining-memory figure is left untouched. */
+    def clear() {
+      _blocks.clear()
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
new file mode 100644
index 0000000000..30483b0b37
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerMessages.scala
@@ -0,0 +1,100 @@
+package spark.storage
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import akka.actor.ActorRef
+
+
+//////////////////////////////////////////////////////////////////////////////////
+// Messages from the master to slaves.
+//////////////////////////////////////////////////////////////////////////////////
+private[spark]
+sealed trait ToBlockManagerSlave
+
+// Remove a block from the slaves that have it. This can only be used to remove
+// blocks that the master knows about.
+private[spark]
+case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+
+
+//////////////////////////////////////////////////////////////////////////////////
+// Messages from slaves to the master.
+//////////////////////////////////////////////////////////////////////////////////
+private[spark]
+sealed trait ToBlockManagerMaster
+
+private[spark]
+case class RegisterBlockManager(
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ sender: ActorRef)
+ extends ToBlockManagerMaster
+
+private[spark]
+case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+
+/**
+ * Block status report sent from a block manager to the master, with hand-written
+ * Externalizable serialization. Fields are `var` because readExternal fills them in
+ * on an instance created through the no-argument constructor.
+ */
+private[spark]
+class UpdateBlockInfo(
+    var blockManagerId: BlockManagerId,
+    var blockId: String,
+    var storageLevel: StorageLevel,
+    var memSize: Long,
+    var diskSize: Long)
+  extends ToBlockManagerMaster
+  with Externalizable {
+
+  def this() = this(null, null, null, 0, 0)  // For deserialization only
+
+  /** Write the fields in a fixed order; readExternal must read them back identically. */
+  override def writeExternal(out: ObjectOutput) {
+    blockManagerId.writeExternal(out)
+    out.writeUTF(blockId)
+    storageLevel.writeExternal(out)
+    // Sizes are Longs: write them at full width so values above Int.MaxValue
+    // (blocks larger than 2 GiB) are not truncated on the wire.
+    out.writeLong(memSize)
+    out.writeLong(diskSize)
+  }
+
+  /** Read the fields in the order written by writeExternal. */
+  override def readExternal(in: ObjectInput) {
+    blockManagerId = BlockManagerId(in)
+    blockId = in.readUTF()
+    storageLevel = StorageLevel(in)
+    memSize = in.readLong()
+    diskSize = in.readLong()
+  }
+}
+
+private[spark]
+object UpdateBlockInfo {
+ def apply(blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long): UpdateBlockInfo = {
+ new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+ }
+
+ // For pattern-matching
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+ }
+}
+
+private[spark]
+case class GetLocations(blockId: String) extends ToBlockManagerMaster
+
+private[spark]
+case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+
+private[spark]
+case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+
+private[spark]
+case class RemoveHost(host: String) extends ToBlockManagerMaster
+
+private[spark]
+case object StopBlockManagerMaster extends ToBlockManagerMaster
+
+private[spark]
+case object GetMemoryStatus extends ToBlockManagerMaster
+
+private[spark]
+case object ExpireDeadHosts extends ToBlockManagerMaster
diff --git a/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
new file mode 100644
index 0000000000..f570cdc52d
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockManagerSlaveActor.scala
@@ -0,0 +1,16 @@
+package spark.storage
+
+import akka.actor.Actor
+
+import spark.{Logging, SparkException, Utils}
+
+
+/**
+ * An actor to take commands from the master to execute operations. For example,
+ * this is used to remove blocks from the slave's BlockManager.
+ */
+class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
+ override def receive = {
+ case RemoveBlock(blockId) => blockManager.removeBlock(blockId)
+ }
+}
diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala
index 3f234df654..30d7500e01 100644
--- a/core/src/main/scala/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/spark/storage/BlockMessage.scala
@@ -64,7 +64,7 @@ private[spark] class BlockMessage() {
val booleanInt = buffer.getInt()
val replication = buffer.getInt()
- level = new StorageLevel(booleanInt, replication)
+ level = StorageLevel(booleanInt, replication)
val dataLength = buffer.getInt()
data = ByteBuffer.allocate(dataLength)
diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala
index 096bf8bdd9..8188d3595e 100644
--- a/core/src/main/scala/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/spark/storage/BlockStore.scala
@@ -31,7 +31,12 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
def getValues(blockId: String): Option[Iterator[Any]]
- def remove(blockId: String)
+ /**
+ * Remove a block, if it exists.
+ * @param blockId the block to remove.
+ * @return True if the block was found and removed, False otherwise.
+ */
+ def remove(blockId: String): Boolean
def contains(blockId: String): Boolean
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala
index b5561479db..7e5b820cbb 100644
--- a/core/src/main/scala/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/spark/storage/DiskStore.scala
@@ -92,10 +92,13 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
}
- override def remove(blockId: String) {
+ override def remove(blockId: String): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
file.delete()
+ true
+ } else {
+ false
}
}
diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala
index 02098b82fe..ae88ff0bb1 100644
--- a/core/src/main/scala/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/spark/storage/MemoryStore.scala
@@ -17,7 +17,6 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
private var currentMemory = 0L
-
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
private val putLock = new Object()
@@ -90,7 +89,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String) {
+ override def remove(blockId: String): Boolean = {
entries.synchronized {
val entry = entries.get(blockId)
if (entry != null) {
@@ -98,8 +97,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
currentMemory -= entry.size
logInfo("Block %s of size %d dropped from memory (free %d)".format(
blockId, entry.size, freeMemory))
+ true
} else {
- logWarning("Block " + blockId + " could not be removed as it does not exist")
+ false
}
}
}
diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala
index c497f03e0c..d1d1c61c1c 100644
--- a/core/src/main/scala/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/spark/storage/StorageLevel.scala
@@ -1,53 +1,60 @@
package spark.storage
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
/**
* Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
* whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
* in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
* The [[spark.storage.StorageLevel$]] singleton object contains some static constants for
- * commonly useful storage levels.
+ * commonly useful storage levels. To create your own storage level object, use the factory
+ * method of the singleton object (`StorageLevel(...)`).
*/
-class StorageLevel(
- var useDisk: Boolean,
- var useMemory: Boolean,
- var deserialized: Boolean,
- var replication: Int = 1)
+class StorageLevel private(
+ private var useDisk_ : Boolean,
+ private var useMemory_ : Boolean,
+ private var deserialized_ : Boolean,
+ private var replication_ : Int = 1)
extends Externalizable {
// TODO: Also add fields for caching priority, dataset ID, and flushing.
-
- def this(flags: Int, replication: Int) {
+ private def this(flags: Int, replication: Int) {
this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
}
def this() = this(false, true, false) // For deserialization
+ def useDisk = useDisk_
+ def useMemory = useMemory_
+ def deserialized = deserialized_
+ def replication = replication_
+
+ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
override def clone(): StorageLevel = new StorageLevel(
this.useDisk, this.useMemory, this.deserialized, this.replication)
override def equals(other: Any): Boolean = other match {
case s: StorageLevel =>
- s.useDisk == useDisk &&
+ s.useDisk == useDisk &&
s.useMemory == useMemory &&
s.deserialized == deserialized &&
- s.replication == replication
+ s.replication == replication
case _ =>
false
}
-
+
def isValid = ((useMemory || useDisk) && (replication > 0))
def toInt: Int = {
var ret = 0
- if (useDisk) {
+ if (useDisk_) {
ret |= 4
}
- if (useMemory) {
+ if (useMemory_) {
ret |= 2
}
- if (deserialized) {
+ if (deserialized_) {
ret |= 1
}
return ret
@@ -55,21 +62,27 @@ class StorageLevel(
override def writeExternal(out: ObjectOutput) {
out.writeByte(toInt)
- out.writeByte(replication)
+ out.writeByte(replication_)
}
override def readExternal(in: ObjectInput) {
val flags = in.readByte()
- useDisk = (flags & 4) != 0
- useMemory = (flags & 2) != 0
- deserialized = (flags & 1) != 0
- replication = in.readByte()
+ useDisk_ = (flags & 4) != 0
+ useMemory_ = (flags & 2) != 0
+ deserialized_ = (flags & 1) != 0
+ replication_ = in.readByte()
}
+ @throws(classOf[IOException])
+ private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
+
override def toString: String =
"StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+
+ override def hashCode(): Int = toInt * 41 + replication
}
+
object StorageLevel {
val NONE = new StorageLevel(false, false, false)
val DISK_ONLY = new StorageLevel(true, false, false)
@@ -82,4 +95,31 @@ object StorageLevel {
val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+
+ /** Create a new StorageLevel object */
+ def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
+ getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+
+ /** Create a new StorageLevel object from its integer representation */
+ def apply(flags: Int, replication: Int) =
+ getCachedStorageLevel(new StorageLevel(flags, replication))
+
+ /** Read StorageLevel object from ObjectInput stream */
+ def apply(in: ObjectInput) = {
+ val obj = new StorageLevel()
+ obj.readExternal(in)
+ getCachedStorageLevel(obj)
+ }
+
+ private[spark]
+ val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+
+ /**
+  * Returns the canonical cached instance that equals `level`. If no equal level has been
+  * cached yet, `level` itself is registered as the canonical instance and returned.
+  *
+  * @param level the storage level to canonicalize
+  * @return the single cached instance equal to `level`
+  */
+ private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
+   // putIfAbsent is one atomic step. A separate containsKey/get/put sequence lets two threads
+   // both miss the cache and each walk away with a different "canonical" object.
+   val existing = storageLevelCache.putIfAbsent(level, level)
+   if (existing == null) level else existing
+ }
}
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index 5bb5a29cc4..689f07b969 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -58,8 +58,10 @@ private[spark] object ThreadingTest {
val startTime = System.currentTimeMillis()
manager.get(blockId) match {
case Some(retrievedBlock) =>
- assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList, "Block " + blockId + " did not match")
- println("Got block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
+ assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList,
+ "Block " + blockId + " did not match")
+ println("Got block " + blockId + " in " +
+ (System.currentTimeMillis - startTime) + " ms")
case None =>
assert(false, "Block " + blockId + " could not be retrieved")
}
@@ -73,7 +75,9 @@ private[spark] object ThreadingTest {
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
- val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true)
+ val masterIp: String = System.getProperty("spark.master.host", "localhost")
+ val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
+ val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, masterIp, masterPort)
val blockManager = new BlockManager(actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
@@ -86,6 +90,7 @@ private[spark] object ThreadingTest {
actorSystem.shutdown()
actorSystem.awaitTermination()
println("Everything stopped.")
- println("It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
+ println(
+ "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
}
}
diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala
index e67cb0336d..fbd0ff46bf 100644
--- a/core/src/main/scala/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/spark/util/AkkaUtils.scala
@@ -32,6 +32,7 @@ private[spark] object AkkaUtils {
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
akka.actor.provider = "akka.remote.RemoteActorRefProvider"
akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
+ akka.remote.log-remote-lifecycle-events = on
akka.remote.netty.hostname = "%s"
akka.remote.netty.port = %d
akka.remote.netty.connection-timeout = %ds
diff --git a/core/src/main/scala/spark/util/IdGenerator.scala b/core/src/main/scala/spark/util/IdGenerator.scala
new file mode 100644
index 0000000000..b6e309fe1a
--- /dev/null
+++ b/core/src/main/scala/spark/util/IdGenerator.scala
@@ -0,0 +1,14 @@
+package spark.util
+
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * A util used to get a unique generation ID. This is a wrapper around Java's
+ * AtomicInteger. An example usage is in BlockManager, where each BlockManager
+ * instance would start an Akka actor and we use this utility to assign the Akka
+ * actors unique names.
+ */
+private[spark] class IdGenerator {
+  // The reference is never reassigned; only the counter inside it changes, so a val suffices.
+  private val id = new AtomicInteger
+
+  /** Increments the counter and returns the new value; the first call returns 1. */
+  def next: Int = id.incrementAndGet
+}
diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala
new file mode 100644
index 0000000000..139e21d09e
--- /dev/null
+++ b/core/src/main/scala/spark/util/MetadataCleaner.scala
@@ -0,0 +1,44 @@
+package spark.util
+
+import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
+import java.util.{TimerTask, Timer}
+import spark.Logging
+
+
+/**
+ * Runs `cleanupFunc` periodically on a daemon timer thread.
+ *
+ * On every tick, if the configured delay (MetadataCleaner.getDelaySeconds, read once at
+ * construction) is positive, `cleanupFunc` is called with the timestamp
+ * `System.currentTimeMillis() - delay`. Exceptions thrown by `cleanupFunc` are logged and do
+ * not stop the timer.
+ *
+ * @param name        label used in the timer thread name and in log messages
+ * @param cleanupFunc callback receiving the threshold time in milliseconds
+ */
+class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+
+  val delaySeconds = MetadataCleaner.getDelaySeconds
+  // The period is one tenth of the delay, but never shorter than 10 seconds.
+  val periodSeconds = math.max(10, delaySeconds / 10)
+  val timer = new Timer(name + " cleanup timer", true)
+
+  val task = new TimerTask {
+    def run() {
+      try {
+        if (delaySeconds > 0) {
+          // Multiply as Long: an Int product overflows once the delay exceeds ~24.8 days,
+          // which would move the threshold into the future.
+          cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000L))
+          logInfo("Ran metadata cleaner for " + name)
+        }
+      } catch {
+        case e: Exception => logError("Error running cleanup task for " + name, e)
+      }
+    }
+  }
+
+  if (periodSeconds > 0) {
+    logInfo(
+      "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
+      + "period of " + periodSeconds + " secs")
+    // Long arithmetic here as well, for the same overflow reason as above.
+    val periodMillis = periodSeconds * 1000L
+    timer.schedule(task, periodMillis, periodMillis)
+  }
+
+  /** Cancels the underlying timer; no further cleanup runs are scheduled. */
+  def cancel() {
+    timer.cancel()
+  }
+}
+
+
+/**
+ * Accessors for the `spark.cleaner.delay` system property. The property holds the delay in
+ * minutes (possibly fractional); both methods here speak in seconds.
+ */
+object MetadataCleaner {
+  /**
+   * Returns the configured delay in seconds, rounded to the nearest whole second.
+   * When the property is unset the default of -100 minutes yields -6000.
+   */
+  def getDelaySeconds =
+    math.rint(System.getProperty("spark.cleaner.delay", "-100").toDouble * 60).toInt
+
+  /**
+   * Stores `delay` (in seconds) into the property. The value is converted to minutes because
+   * that is the unit getDelaySeconds expects; writing raw seconds would make a set/get round
+   * trip multiply the delay by 60.
+   */
+  def setDelaySeconds(delay: Long) {
+    System.setProperty("spark.cleaner.delay", (delay / 60.0).toString)
+  }
+}
+
diff --git a/core/src/main/scala/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
new file mode 100644
index 0000000000..e3f00ea8c7
--- /dev/null
+++ b/core/src/main/scala/spark/util/RateLimitedOutputStream.scala
@@ -0,0 +1,62 @@
+package spark.util
+
+import scala.annotation.tailrec
+
+import java.io.OutputStream
+import java.util.concurrent.TimeUnit._
+
+/**
+ * An OutputStream wrapper that limits the rate at which bytes are handed to `out`. When the
+ * observed rate since the last sync point reaches `bytesPerSec`, the writing thread is put to
+ * sleep until the rate drops below the limit.
+ *
+ * @param out         the stream that receives the bytes
+ * @param bytesPerSec target throughput in bytes per second; 0 causes a division by zero in the
+ *                    throttling branch
+ */
+class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
+  val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+  val CHUNK_SIZE = 8192
+  var lastSyncTime = System.nanoTime
+  var bytesWrittenSinceSync: Long = 0
+
+  override def write(b: Int) {
+    waitToWrite(1)
+    out.write(b)
+  }
+
+  override def write(bytes: Array[Byte]) {
+    write(bytes, 0, bytes.length)
+  }
+
+  /**
+   * Writes `length` bytes of `bytes` starting at `offset`, in chunks of at most CHUNK_SIZE,
+   * waiting before each chunk as required by the rate limit.
+   */
+  @tailrec
+  override final def write(bytes: Array[Byte], offset: Int, length: Int) {
+    // `length` is a byte count (the OutputStream contract), not an end index: the recursion
+    // advances the offset and shrinks the remaining count by the same amount.
+    val writeSize = math.min(length, CHUNK_SIZE)
+    if (writeSize > 0) {
+      waitToWrite(writeSize)
+      out.write(bytes, offset, writeSize)
+      write(bytes, offset + writeSize, length - writeSize)
+    }
+  }
+
+  override def flush() {
+    out.flush()
+  }
+
+  override def close() {
+    out.close()
+  }
+
+  /** Blocks until writing `numBytes` more bytes is allowed, then accounts for them. */
+  @tailrec
+  private def waitToWrite(numBytes: Int) {
+    val now = System.nanoTime
+    // Work in nanoseconds: truncating to whole seconds gives 0 during the first second, which
+    // makes the rate NaN/Infinity and turns the else-branch into a busy loop.
+    val elapsedNanos = math.max(now - lastSyncTime, 1)
+    val rate = bytesWrittenSinceSync.toDouble * 1000000000 / elapsedNanos
+    if (rate < bytesPerSec) {
+      // It's okay to write; just update some variables and return
+      bytesWrittenSinceSync += numBytes
+      if (now > lastSyncTime + SYNC_INTERVAL) {
+        // Sync interval has passed; let's resync
+        lastSyncTime = now
+        bytesWrittenSinceSync = numBytes
+      }
+    } else {
+      // Calculate how much time we should sleep to bring ourselves to the desired rate.
+      // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala)
+      val targetMillis = bytesWrittenSinceSync * 1000 / bytesPerSec
+      val elapsedMillis = MILLISECONDS.convert(elapsedNanos, NANOSECONDS)
+      val sleepTime = targetMillis - elapsedMillis
+      if (sleepTime > 0) Thread.sleep(sleepTime)
+      waitToWrite(numBytes)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
new file mode 100644
index 0000000000..bb7c5c01c8
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,93 @@
+package spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConversions
+import scala.collection.mutable.Map
+
+/**
+ * A scala.collection.mutable.Map implementation that remembers, for every key-value pair, the
+ * time at which it was inserted. Pairs older than a given threshold can then be dropped with
+ * the clearOldValues method. Intended as a drop-in replacement for
+ * scala.collection.mutable.HashMap.
+ */
+class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
+  // Each value is stored together with its insertion time in milliseconds.
+  val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+
+  def get(key: A): Option[B] = Option(internalMap.get(key)).map(_._1)
+
+  def iterator: Iterator[(A, B)] = {
+    val entries = JavaConversions.asScalaIterator(internalMap.entrySet().iterator())
+    entries.map(entry => (entry.getKey, entry.getValue._1))
+  }
+
+  /** Returns a new map holding this map's entries (timestamps kept) plus `kv`. */
+  override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+    val extended = new TimeStampedHashMap[A, B1]
+    extended.internalMap.putAll(this.internalMap)
+    extended.internalMap.put(kv._1, (kv._2, currentTime))
+    extended
+  }
+
+  /** Returns a new map holding this map's entries (timestamps kept) without `key`. */
+  override def - (key: A): Map[A, B] = {
+    val reduced = new TimeStampedHashMap[A, B]
+    reduced.internalMap.putAll(this.internalMap)
+    reduced.internalMap.remove(key)
+    reduced
+  }
+
+  override def += (kv: (A, B)): this.type = {
+    internalMap.put(kv._1, (kv._2, currentTime))
+    this
+  }
+
+  override def -= (key: A): this.type = {
+    internalMap.remove(key)
+    this
+  }
+
+  override def update(key: A, value: B) {
+    this += ((key, value))
+  }
+
+  override def apply(key: A): B = internalMap.get(key) match {
+    case null => throw new NoSuchElementException()
+    case stamped => stamped._1
+  }
+
+  override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
+    val withoutTimestamps =
+      JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1))
+    withoutTimestamps.filter(p)
+  }
+
+  override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+
+  override def size(): Int = internalMap.size()
+
+  override def foreach[U](f: ((A, B)) => U): Unit = {
+    val entries = internalMap.entrySet().iterator()
+    while (entries.hasNext) {
+      val entry = entries.next()
+      f((entry.getKey, entry.getValue._1))
+    }
+  }
+
+  /**
+   * Removes old key-value pairs that have timestamp earlier than `threshTime`
+   */
+  def clearOldValues(threshTime: Long) {
+    val entries = internalMap.entrySet().iterator()
+    while (entries.hasNext) {
+      val entry = entries.next()
+      val insertedAt = entry.getValue._2
+      if (insertedAt < threshTime) {
+        logDebug("Removing key " + entry.getKey)
+        entries.remove()
+      }
+    }
+  }
+
+  private def currentTime: Long = System.currentTimeMillis()
+
+}
diff --git a/core/src/main/scala/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
new file mode 100644
index 0000000000..5f1cc93752
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimeStampedHashSet.scala
@@ -0,0 +1,69 @@
+package spark.util
+
+import scala.collection.mutable.Set
+import scala.collection.JavaConversions
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A mutable Set backed by a ConcurrentHashMap that maps every element to the time
+ * (System.currentTimeMillis) at which it was last added, so that elements older than a
+ * threshold can be dropped with clearOldValues.
+ */
+class TimeStampedHashSet[A] extends Set[A] {
+  // element -> time of last insertion, in milliseconds
+  val internalMap = new ConcurrentHashMap[A, Long]()
+
+  def contains(key: A): Boolean = {
+    // Must be containsKey: ConcurrentHashMap.contains is a legacy alias of containsValue and
+    // would compare `key` against the stored timestamps instead of the elements.
+    internalMap.containsKey(key)
+  }
+
+  def iterator: Iterator[A] = {
+    val jIterator = internalMap.entrySet().iterator()
+    JavaConversions.asScalaIterator(jIterator).map(_.getKey)
+  }
+
+  /** Returns a new set with this set's elements plus `elem`; all elements are re-stamped. */
+  override def + (elem: A): Set[A] = {
+    val newSet = new TimeStampedHashSet[A]
+    newSet ++= this
+    newSet += elem
+    newSet
+  }
+
+  /** Returns a new set with this set's elements minus `elem`; all elements are re-stamped. */
+  override def - (elem: A): Set[A] = {
+    val newSet = new TimeStampedHashSet[A]
+    newSet ++= this
+    newSet -= elem
+    newSet
+  }
+
+  override def += (key: A): this.type = {
+    internalMap.put(key, currentTime)
+    this
+  }
+
+  override def -= (key: A): this.type = {
+    internalMap.remove(key)
+    this
+  }
+
+  override def empty: Set[A] = new TimeStampedHashSet[A]()
+
+  override def size(): Int = internalMap.size()
+
+  override def foreach[U](f: (A) => U): Unit = {
+    val iterator = internalMap.entrySet().iterator()
+    while(iterator.hasNext) {
+      f(iterator.next.getKey)
+    }
+  }
+
+  /**
+   * Removes old values that have timestamp earlier than `threshTime`
+   */
+  def clearOldValues(threshTime: Long) {
+    val iterator = internalMap.entrySet().iterator()
+    while(iterator.hasNext) {
+      val entry = iterator.next()
+      if (entry.getValue < threshTime) {
+        iterator.remove()
+      }
+    }
+  }
+
+  private def currentTime: Long = System.currentTimeMillis()
+}
diff --git a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
index c32ab30401..be69e9bf02 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_row.scala.html
@@ -7,6 +7,7 @@
<a href="@worker.webUiAddress">@worker.id</href>
</td>
<td>@{worker.host}:@{worker.port}</td>
+ <td>@worker.state</td>
<td>@worker.cores (@worker.coresUsed Used)</td>
<td>@{Utils.memoryMegabytesToString(worker.memory)}
(@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</td>
diff --git a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
index fad1af41dc..b249411a62 100644
--- a/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
+++ b/core/src/main/twirl/spark/deploy/master/worker_table.scala.html
@@ -5,6 +5,7 @@
<tr>
<th>ID</th>
<th>Address</th>
+ <th>State</th>
<th>Cores</th>
<th>Memory</th>
</tr>
diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties
index 4c99e450bc..6ec89c0184 100644
--- a/core/src/test/resources/log4j.properties
+++ b/core/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the file core/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=core/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala b/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
deleted file mode 100644
index 37cafd1e8e..0000000000
--- a/core/src/test/scala/spark/BoundedMemoryCacheSuite.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
-
-// TODO: Replace this with a test of MemoryStore
-class BoundedMemoryCacheSuite extends FunSuite with PrivateMethodTester with ShouldMatchers {
- test("constructor test") {
- val cache = new BoundedMemoryCache(60)
- expect(60)(cache.getCapacity)
- }
-
- test("caching") {
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
- val oldArch = System.setProperty("os.arch", "amd64")
- val oldOops = System.setProperty("spark.test.useCompressedOops", "true")
- val initialize = PrivateMethod[Unit]('initialize)
- SizeEstimator invokePrivate initialize()
-
- val cache = new BoundedMemoryCache(60) {
- //TODO sorry about this, but there is not better way how to skip 'cacheTracker.dropEntry'
- override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) {
- logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size))
- }
- }
-
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
-
- //should be OK
- cache.put("1", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //we cannot add this to cache (there is not enough space in cache) & we cannot evict the only value from
- //cache because it's from the same dataset
- expect(CachePutFailure())(cache.put("1", 1, "Meh"))
-
- //should be OK, dataset '1' can be evicted from cache
- cache.put("2", 0, "Meh") should (equal (CachePutSuccess(56)) or equal (CachePutSuccess(48)))
-
- //should fail, cache should obey it's capacity
- expect(CachePutFailure())(cache.put("3", 0, "Very_long_and_useless_string"))
-
- if (oldArch != null) {
- System.setProperty("os.arch", oldArch)
- } else {
- System.clearProperty("os.arch")
- }
-
- if (oldOops != null) {
- System.setProperty("spark.test.useCompressedOops", oldOops)
- } else {
- System.clearProperty("spark.test.useCompressedOops")
- }
- }
-}
diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala
deleted file mode 100644
index 467605981b..0000000000
--- a/core/src/test/scala/spark/CacheTrackerSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-package spark
-
-import org.scalatest.FunSuite
-
-import scala.collection.mutable.HashMap
-
-import akka.actor._
-import akka.dispatch._
-import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-import akka.util.Timeout
-import akka.util.duration._
-
-class CacheTrackerSuite extends FunSuite {
- // Send a message to an actor and wait for a reply, in a blocking manner
- private def ask(actor: ActorRef, message: Any): Any = {
- try {
- val timeout = 10.seconds
- val future = actor.ask(message)(timeout)
- return Await.result(future, timeout)
- } catch {
- case e: Exception =>
- throw new SparkException("Error communicating with actor", e)
- }
- }
-
- test("CacheTrackerActor slave initialization & cache status") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("RegisterRDD") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 3)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil)))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("AddedToCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
-
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- test("DroppedFromCache") {
- //System.setProperty("spark.master.port", "1345")
- val initialSize = 2L << 20
-
- val actorSystem = ActorSystem("test")
- val tracker = actorSystem.actorOf(Props[CacheTrackerActor])
-
- assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true)
-
- assert(ask(tracker, RegisterRDD(1, 2)) === true)
- assert(ask(tracker, RegisterRDD(2, 1)) === true)
-
- assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true)
- assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true)
- assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001"))))
-
- assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true)
-
- assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L)))
- assert(getCacheLocations(tracker) ===
- Map(1 -> List(List("host001"),List()), 2 -> List(List("host001"))))
-
- assert(ask(tracker, StopCacheTracker) === true)
-
- actorSystem.shutdown()
- actorSystem.awaitTermination()
- }
-
- /**
- * Helper function to get cacheLocations from CacheTracker
- */
- def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = {
- val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]]
- answer.map { case (i, arr) => (i, arr.toList) }
- }
-}
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
new file mode 100644
index 0000000000..51573254ca
--- /dev/null
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -0,0 +1,357 @@
+package spark
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import java.io.File
+import spark.rdd._
+import spark.SparkContext._
+import storage.StorageLevel
+
+class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging {
+ initLogging()
+
+ var sc: SparkContext = _
+ var checkpointDir: File = _
+ val partitioner = new HashPartitioner(2)
+
+ before {
+ checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
+ sc = new SparkContext("local", "test")
+ sc.setCheckpointDir(checkpointDir.toString)
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+
+ if (checkpointDir != null) {
+ checkpointDir.delete()
+ }
+ }
+
+ test("RDDs with one-to-one dependencies") {
+ testCheckpointing(_.map(x => x.toString))
+ testCheckpointing(_.flatMap(x => 1 to x))
+ testCheckpointing(_.filter(_ % 2 == 0))
+ testCheckpointing(_.sample(false, 0.5, 0))
+ testCheckpointing(_.glom())
+ testCheckpointing(_.mapPartitions(_.map(_.toString)))
+ testCheckpointing(r => new MapPartitionsWithSplitRDD(r,
+ (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testCheckpointing(_.pipe(Seq("cat")))
+ }
+
+ test("ParallelCollection") {
+ val parCollection = sc.makeRDD(1 to 4, 2)
+ val numSplits = parCollection.splits.size
+ parCollection.checkpoint()
+ assert(parCollection.dependencies === Nil)
+ val result = parCollection.collect()
+ assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result)
+ assert(parCollection.dependencies != Nil)
+ assert(parCollection.splits.length === numSplits)
+ assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList)
+ assert(parCollection.collect() === result)
+ }
+
+ test("BlockRDD") {
+ val blockId = "id"
+ val blockManager = SparkEnv.get.blockManager
+ blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
+ val blockRDD = new BlockRDD[String](sc, Array(blockId))
+ val numSplits = blockRDD.splits.size
+ blockRDD.checkpoint()
+ val result = blockRDD.collect()
+ assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result)
+ assert(blockRDD.dependencies != Nil)
+ assert(blockRDD.splits.length === numSplits)
+ assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList)
+ assert(blockRDD.collect() === result)
+ }
+
+ test("ShuffledRDD") {
+ testCheckpointing(rdd => {
+ // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
+ new ShuffledRDD(rdd.map(x => (x % 2, 1)), partitioner)
+ })
+ }
+
+ test("UnionRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+
+ // Test whether the size of UnionRDDSplits reduces after the parent RDD is checkpointed.
+ // Current implementation of UnionRDD has transient reference to parent RDDs,
+ // so only the splits will reduce in serialized size, not the RDD.
+ testCheckpointing(_.union(otherRDD), false, true)
+ testParentCheckpointing(_.union(otherRDD), false, true)
+ }
+
+ test("CartesianRDD") {
+ def otherRDD = sc.makeRDD(1 to 10, 1)
+ testCheckpointing(new CartesianRDD(sc, _, otherRDD))
+
+ // Test whether the size of CartesianRDD reduces after the parent RDD is checkpointed.
+ // Current implementation of CartesianSplit has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+
+ // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CartesianRDD.
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val cartesian = new CartesianRDD(sc, ones, ones)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ cartesian.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
+ assert(
+ (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
+ (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
+ "CartesianRDD.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoalescedRDD") {
+ testCheckpointing(new CoalescedRDD(_, 2))
+
+ // Test whether the size of CoalescedRDD reduces after the parent RDD is checkpointed.
+ // Current implementation of CoalescedRDDSplit has transient reference to parent RDD,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(new CoalescedRDD(_, 2), true, false)
+
+ // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after
+ // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
+ // Note that this test is very specific to the current implementation of CoalescedRDDSplits
+ val ones = sc.makeRDD(1 to 100, 10).map(x => x)
+ ones.checkpoint // checkpoint that MappedRDD
+ val coalesced = new CoalescedRDD(ones, 2)
+ val splitBeforeCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ coalesced.count() // do the checkpointing
+ val splitAfterCheckpoint =
+ serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
+ assert(
+ splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
+ "CoalescedRDDSplit.parents not updated after parent RDD checkpointed"
+ )
+ }
+
+ test("CoGroupedRDD") {
+ val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
+ testCheckpointing(rdd => {
+ CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+
+ val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
+ testParentCheckpointing(rdd => {
+ CheckpointSuite.cogroup(
+ longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
+ }, false, true)
+ }
+
+ test("ZippedRDD") {
+ testCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ // Test whether the size of ZippedRDD reduces after the parent RDD is checkpointed.
+ // Current implementation of ZippedRDDSplit has transient references to parent RDDs,
+ // so only the RDD will reduce in serialized size, not the splits.
+ testParentCheckpointing(
+ rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+
+ }
+
+ /**
+ * Test checkpointing of the final RDD generated by the given operation. By default,
+ * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
+ * It can also test whether the size of serialized RDD splits has reduced after checkpointing or
+ * not, but this is not done by default as usually the splits do not refer to any RDD and
+ * therefore never store the lineage.
+ */
+ def testCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean = true,
+ testRDDSplitSize: Boolean = false
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.headOption.orNull
+ val rddType = operatedRDD.getClass.getSimpleName
+ val numSplits = operatedRDD.splits.length
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.checkpoint()
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the checkpoint file has been created
+ assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
+
+ // Test whether dependencies have been changed from its earlier parent RDD
+ assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+ // Test whether the splits have been changed to the new Hadoop splits
+ assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList)
+
+ // Test whether the number of splits is same as before
+ assert(operatedRDD.splits.length === numSplits)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced. If the RDD
+ // does not have any dependency to another RDD (e.g., ParallelCollection,
+ // ShuffledRDD with ShuffleDependency), it may not reduce in size after checkpointing.
+ if (testRDDSize) {
+ logInfo("Size of " + rddType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the splits has reduced. If the splits
+ // do not have any non-transient reference to another RDD or another RDD's splits, it
+ // does not refer to a lineage and therefore may not reduce in size after checkpointing.
+ // However, if the original splits before checkpointing do refer to a parent RDD, the splits
+ // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
+ // replaced with the HadoopSplits of the checkpointed RDD.
+ if (testRDDSplitSize) {
+ logInfo("Size of " + rddType + " splits "
+ + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing " +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+ }
+
+ /**
+ * Test whether checkpointing of the parent of the generated RDD also
+ * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to their parent
+ * RDDs' splits. So even if the parent RDD is checkpointed and its splits changed,
+ * this RDD will remember the splits and therefore potentially the whole lineage.
+ */
+ def testParentCheckpointing[U: ClassManifest](
+ op: (RDD[Int]) => RDD[U],
+ testRDDSize: Boolean,
+ testRDDSplitSize: Boolean
+ ) {
+ // Generate the final RDD using given RDD operation
+ val baseRDD = generateLongLineageRDD
+ val operatedRDD = op(baseRDD)
+ val parentRDD = operatedRDD.dependencies.head.rdd
+ val rddType = operatedRDD.getClass.getSimpleName
+ val parentRDDType = parentRDD.getClass.getSimpleName
+
+ // Find serialized sizes before and after the checkpoint
+ val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect()
+ val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+
+ // Test whether the data in the checkpointed RDD is same as original
+ assert(operatedRDD.collect() === result)
+
+ // Test whether serialized size of the RDD has reduced because of its parent being
+ // checkpointed. If this RDD or its parent RDD do not have any dependency
+ // to another RDD (e.g., ParallelCollection, ShuffledRDD with ShuffleDependency), it may
+ // not reduce in size after checkpointing.
+ if (testRDDSize) {
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
+ "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
+ }
+
+ // Test whether serialized size of the splits has reduced because of its parent being
+ // checkpointed. If the splits do not have any non-transient reference to another RDD
+ // or another RDD's splits, it does not refer to a lineage and therefore may not reduce
+ // in size after checkpointing. However, if the splits do refer to the *splits* of a parent
+ // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's
+ // splits must have changed after checkpointing.
+ if (testRDDSplitSize) {
+ assert(
+ splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
+ "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType +
+ "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
+ )
+ }
+
+ }
+
+ /**
+ * Generate an RDD with a long lineage of one-to-one dependencies.
+ */
+ def generateLongLineageRDD(): RDD[Int] = {
+ var rdd = sc.makeRDD(1 to 100, 4)
+ for (i <- 1 to 50) {
+ rdd = rdd.map(x => x + 1)
+ }
+ rdd
+ }
+
+ /**
+ * Generate an RDD with a long lineage specifically for CoGroupedRDD.
+ * A CoGroupedRDD can have a long lineage only if one of its parents has a long lineage
+ * and a narrow dependency with this RDD. This method generates such an RDD by a sequence
+ * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+ */
+ def generateLongLineageRDDForCoGroupedRDD() = {
+ val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
+
+ def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+
+ var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
+ for(i <- 1 to 10) {
+ cogrouped = cogrouped.mapValues(add).cogroup(ones)
+ }
+ cogrouped.mapValues(add)
+ }
+
+ /**
+ * Get serialized sizes of the RDD and its splits
+ */
+ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+ (Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
+ }
+
+ /**
+ * Serialize and deserialize an object. This is useful to verify the objects
+ * contents after deserialization (e.g., the contents of an RDD split after
+ * it is sent to a slave along with a task)
+ */
+ def serializeDeserialize[T](obj: T): T = {
+ val bytes = Utils.serialize(obj)
+ Utils.deserialize[T](bytes)
+ }
+}
+
+
+object CheckpointSuite {
+ // This is a custom cogroup function that does not use mapValues like
+ // the PairRDDFunctions.cogroup()
+ def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = {
+ //println("First = " + first + ", second = " + second)
+ new CoGroupedRDD[K](
+ Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]),
+ part
+ ).asInstanceOf[RDD[(K, Seq[Seq[V]])]]
+ }
+
+}
diff --git a/core/src/test/scala/spark/ClosureCleanerSuite.scala b/core/src/test/scala/spark/ClosureCleanerSuite.scala
index 7c0334d957..dfa2de80e6 100644
--- a/core/src/test/scala/spark/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/spark/ClosureCleanerSuite.scala
@@ -47,6 +47,8 @@ object TestObject {
val nums = sc.parallelize(Array(1, 2, 3, 4))
val answer = nums.map(_ + x).reduce(_ + _)
sc.stop()
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
return answer
}
}
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index cacc2796b6..0487e06d12 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -188,4 +188,73 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val values = sc.parallelize(1 to 2, 2).map(x => System.getenv("TEST_VAR")).collect()
assert(values.toSeq === Seq("TEST_VALUE", "TEST_VALUE"))
}
+
+ test("recover from node failures") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2) // force executors to start
+ val masterId = SparkEnv.get.blockManager.blockManagerId
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).collect.size === 2)
+ }
+
+ test("recover from repeated node failures during shuffle-map") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, false), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ assert(data.map(failOnMarkedIdentity).map(x => x -> x).groupByKey.count === 2)
+ }
+ }
+
+ test("recover from repeated node failures during shuffle-reduce") {
+ import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+ DistributedSuite.amMaster = true
+ sc = new SparkContext(clusterUrl, "test")
+ for (i <- 1 to 3) {
+ val data = sc.parallelize(Seq(true, true), 2)
+ assert(data.count === 2)
+ assert(data.map(markNodeIfIdentity).collect.size === 2)
+ // This relies on mergeCombiners being used to perform the actual reduce for this
+ // test to actually be testing what it claims.
+ val grouped = data.map(x => x -> x).combineByKey(
+ x => x,
+ (x: Boolean, y: Boolean) => x,
+ (x: Boolean, y: Boolean) => failOnMarkedIdentity(x)
+ )
+ assert(grouped.collect.size === 1)
+ }
+ }
+}
+
+object DistributedSuite {
+ // Indicates whether this JVM is marked for failure.
+ var mark = false
+
+ // Set by test to remember if we are in the driver program so we can assert
+ // that we are not.
+ var amMaster = false
+
+ // Act like an identity function, but if the argument is true, set mark to true.
+ def markNodeIfIdentity(item: Boolean): Boolean = {
+ if (item) {
+ assert(!amMaster)
+ mark = true
+ }
+ item
+ }
+
+ // Act like an identity function, but if mark was set to true previously, fail,
+ // crashing the entire JVM.
+ def failOnMarkedIdentity(item: Boolean): Boolean = {
+ if (mark) {
+ System.exit(42)
+ }
+ item
+ }
}
diff --git a/core/src/test/scala/spark/DriverSuite.scala b/core/src/test/scala/spark/DriverSuite.scala
new file mode 100644
index 0000000000..70a7c8bc2f
--- /dev/null
+++ b/core/src/test/scala/spark/DriverSuite.scala
@@ -0,0 +1,31 @@
+package spark
+
+import java.io.File
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.prop.TableDrivenPropertyChecks._
+import org.scalatest.time.SpanSugar._
+
+class DriverSuite extends FunSuite with Timeouts {
+ test("driver should exit after finishing") {
+ // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
+ val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
+ forAll(masters) { (master: String) =>
+ failAfter(10 seconds) {
+ Utils.execute(Seq("./run", "spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME")))
+ }
+ }
+ }
+}
+
+/**
+ * Program that creates a Spark driver but doesn't call SparkContext.stop() or
+ * System.exit() after finishing.
+ */
+object DriverWithoutCleanup {
+ def main(args: Array[String]) {
+ val sc = new SparkContext(args(0), "DriverWithoutCleanup")
+ sc.parallelize(1 to 100, 4).count()
+ }
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/FileServerSuite.scala b/core/src/test/scala/spark/FileServerSuite.scala
index b4283d9604..b9e1248829 100644
--- a/core/src/test/scala/spark/FileServerSuite.scala
+++ b/core/src/test/scala/spark/FileServerSuite.scala
@@ -9,8 +9,8 @@ import SparkContext._
class FileServerSuite extends FunSuite with BeforeAndAfter {
@transient var sc: SparkContext = _
- @transient var tmpFile : File = _
- @transient var testJarFile : File = _
+ @transient var tmpFile: File = _
+ @transient var testJarFile: File = _
before {
// Create a sample text file
@@ -40,7 +40,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -54,7 +55,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile((new File(tmpFile.toString)).toURL.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
@@ -83,7 +85,8 @@ class FileServerSuite extends FunSuite with BeforeAndAfter {
sc.addFile(tmpFile.toString)
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0))
val result = sc.parallelize(testData).reduceByKey {
- val in = new BufferedReader(new FileReader("FileServerSuite.txt"))
+ val path = SparkFiles.get("FileServerSuite.txt")
+ val in = new BufferedReader(new FileReader(path))
val fileVal = in.readLine().toInt
in.close()
_ * fileVal + _ * fileVal
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 46a0b68f89..01351de4ae 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -131,6 +131,17 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void lookup() {
+ JavaPairRDD<String, String> categories = sc.parallelizePairs(Arrays.asList(
+ new Tuple2<String, String>("Apples", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Fruit"),
+ new Tuple2<String, String>("Oranges", "Citrus")
+ ));
+ Assert.assertEquals(2, categories.lookup("Oranges").size());
+ Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size());
+ }
+
+ @Test
public void groupBy() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
Function<Integer, Boolean> isOdd = new Function<Integer, Boolean>() {
@@ -570,4 +581,91 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
zipped.count();
}
+
+ @Test
+ public void accumulators() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ intAccum.add(x);
+ }
+ });
+ Assert.assertEquals((Integer) 25, intAccum.value());
+
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ doubleAccum.add((double) x);
+ }
+ });
+ Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+ // Try a custom accumulator type
+ AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+ public Float addInPlace(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float addAccumulator(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float zero(Float initialValue) {
+ return 0.0f;
+ }
+ };
+
+ final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ floatAccum.add((float) x);
+ }
+ });
+ Assert.assertEquals((Float) 25.0f, floatAccum.value());
+
+ // Test the setValue method
+ floatAccum.setValue(5.0f);
+ Assert.assertEquals((Float) 5.0f, floatAccum.value());
+ }
+
+ @Test
+ public void keyBy() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2));
+ List<Tuple2<String, Integer>> s = rdd.keyBy(new Function<Integer, String>() {
+ public String call(Integer t) throws Exception {
+ return t.toString();
+ }
+ }).collect();
+ Assert.assertEquals(new Tuple2<String, Integer>("1", 1), s.get(0));
+ Assert.assertEquals(new Tuple2<String, Integer>("2", 2), s.get(1));
+ }
+
+ @Test
+ public void checkpointAndComputation() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), rdd.collect());
+ }
+
+ @Test
+ public void checkpointAndRestore() {
+ File tempDir = Files.createTempDir();
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ sc.setCheckpointDir(tempDir.getAbsolutePath(), true);
+ Assert.assertEquals(false, rdd.isCheckpointed());
+ rdd.checkpoint();
+ rdd.count(); // Forces the DAG to cause a checkpoint
+ Assert.assertEquals(true, rdd.isCheckpointed());
+
+ Assert.assertTrue(rdd.getCheckpointFile().isPresent());
+ JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get());
+ Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
+ }
}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index 5b4b198960..095f415978 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -1,12 +1,18 @@
package spark
import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
import akka.actor._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
+import spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite {
+class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter {
+ after {
+ System.clearProperty("spark.master.port")
+ }
+
test("compressSize") {
assert(MapOutputTracker.compressSize(0L) === 0)
assert(MapOutputTracker.compressSize(1L) === 1)
@@ -41,13 +47,13 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
Array(compressedSize1000, compressedSize10000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
Array(compressedSize10000, compressedSize1000)))
val statuses = tracker.getServerStatuses(10, 0)
- assert(statuses.toSeq === Seq((new BlockManagerId("hostA", 1000), size1000),
- (new BlockManagerId("hostB", 1000), size10000)))
+ assert(statuses.toSeq === Seq((BlockManagerId("hostA", 1000), size1000),
+ (BlockManagerId("hostB", 1000), size10000)))
tracker.stop()
}
@@ -59,18 +65,48 @@ class MapOutputTrackerSuite extends FunSuite {
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
- tracker.registerMapOutput(10, 0, new MapStatus(new BlockManagerId("hostA", 1000),
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("hostA", 1000),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
- tracker.registerMapOutput(10, 1, new MapStatus(new BlockManagerId("hostB", 1000),
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("hostB", 1000),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
- tracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
+ tracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
// The remaining reduce task might try to grab the output dispite the shuffle failure;
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
- intercept[Exception] { tracker.getServerStatuses(10, 1) }
+ intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+ }
+
+ test("remote fetch") {
+ System.clearProperty("spark.master.host")
+ val (actorSystem, boundPort) =
+ AkkaUtils.createActorSystem("test", "localhost", 0)
+ System.setProperty("spark.master.port", boundPort.toString)
+ val masterTracker = new MapOutputTracker(actorSystem, true)
+ val slaveTracker = new MapOutputTracker(actorSystem, false)
+ masterTracker.registerShuffle(10, 1)
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+ masterTracker.registerMapOutput(10, 0, new MapStatus(
+ BlockManagerId("hostA", 1000), Array(compressedSize1000)))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+ Seq((BlockManagerId("hostA", 1000), size1000)))
+
+ masterTracker.unregisterMapOutput(10, 0, BlockManagerId("hostA", 1000))
+ masterTracker.incrementGeneration()
+ slaveTracker.updateGeneration(masterTracker.getGeneration)
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ // failure should be cached
+ intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
}
}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index 3dadc7acec..eb3c8f238f 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -106,5 +106,31 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
+ }
+
+ test("partitioning Java arrays should fail") {
+ sc = new SparkContext("local", "test")
+ val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
+ val arrPairs: RDD[(Array[Int], Int)] =
+ sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
+
+ assert(intercept[SparkException]{ arrs.distinct() }.getMessage.contains("array"))
+ // We can't catch all usages of arrays, since they might occur inside other collections:
+ //assert(fails { arrPairs.distinct() })
+ assert(intercept[SparkException]{ arrPairs.partitionBy(new HashPartitioner(2)) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.join(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.leftOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.rightOuterJoin(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.groupByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKey() }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.countByKeyApprox(1) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.cogroup(arrPairs) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
+ assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index b3c820ed94..db217f8482 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -8,9 +8,9 @@ import spark.rdd.CoalescedRDD
import SparkContext._
class RDDSuite extends FunSuite with BeforeAndAfter {
-
+
var sc: SparkContext = _
-
+
after {
if (sc != null) {
sc.stop()
@@ -19,11 +19,15 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
}
-
+
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
+ val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
+ assert(dups.distinct.count === 4)
+ assert(dups.distinct().collect === dups.distinct.collect)
+ assert(dups.distinct(2).collect === dups.distinct.collect)
assert(nums.reduce(_ + _) === 10)
assert(nums.fold(0)(_ + _) === 10)
assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
@@ -31,6 +35,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
assert(nums.union(nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
+ assert(nums.collect({ case i if i >= 3 => i.toString }).collect().toList === List("3", "4"))
+ assert(nums.keyBy(_.toString).collect().toList === List(("1", 1), ("2", 2), ("3", 3), ("4", 4)))
val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
assert(partitionSums.collect().toList === List(3, 7))
@@ -70,10 +76,23 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
- test("checkpointing") {
+ test("basic checkpointing") {
+ import java.io.File
+ val checkpointDir = File.createTempFile("temp", "")
+ checkpointDir.delete()
+
sc = new SparkContext("local", "test")
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
- assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+ sc.setCheckpointDir(checkpointDir.toString)
+ val parCollection = sc.makeRDD(1 to 4)
+ val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+ flatMappedRDD.checkpoint()
+ assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+ val result = flatMappedRDD.collect()
+ Thread.sleep(1000)
+ assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+ assert(flatMappedRDD.collect() === result)
+
+ checkpointDir.deleteOnExit()
}
test("basic caching") {
@@ -84,6 +103,29 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
+  // Checks that a cached RDD whose first computation throws can still be
+  // computed successfully on a later action once the failure is cleared.
+  test("caching with failures") {
+    sc = new SparkContext("local", "test")
+    val onlySplit = new Split { override def index: Int = 0 }
+    // Toggled between the two collect() calls to switch the injected failure off.
+    var shouldFail = true
+    val rdd = new RDD[Int](sc, Nil) {
+      override def getSplits: Array[Split] = Array(onlySplit)
+      override val getDependencies = List[Dependency[_]]()
+      // The if/else is the method's result expression, so no explicit `return` is needed.
+      override def compute(split: Split, context: TaskContext): Iterator[Int] = {
+        if (shouldFail) {
+          throw new Exception("injected failure")
+        } else {
+          Array(1, 2, 3, 4).iterator
+        }
+      }
+    }.cache()
+    // First action: the injected exception must surface to the caller.
+    val thrown = intercept[Exception]{
+      rdd.collect()
+    }
+    assert(thrown.getMessage.contains("injected failure"))
+    // Second action: with the failure disabled the same RDD must produce its data.
+    shouldFail = false
+    assert(rdd.collect().toList === List(1, 2, 3, 4))
+  }
+
test("coalesced RDDs") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
@@ -94,8 +136,8 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10)))
// Check that the narrow dependency is also specified correctly
- assert(coalesced1.dependencies.head.getParents(0).toList === List(0, 1, 2, 3, 4))
- assert(coalesced1.dependencies.head.getParents(1).toList === List(5, 6, 7, 8, 9))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(0).toList === List(0, 1, 2, 3, 4))
+ assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9))
val coalesced2 = new CoalescedRDD(data, 3)
assert(coalesced2.collect().toList === (1 to 10).toList)
@@ -121,7 +163,7 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
-
+
intercept[IllegalArgumentException] {
nums.zip(sc.parallelize(1 to 4, 1)).collect()
}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8170100f1d..bebb8ebe86 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -216,6 +216,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
+
+ test("keys and values") {
+ sc = new SparkContext("local", "test")
+ val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
+ assert(rdd.keys.collect().toList === List(1, 2))
+ assert(rdd.values.collect().toList === List("a", "b"))
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/spark/SizeEstimatorSuite.scala b/core/src/test/scala/spark/SizeEstimatorSuite.scala
index 17f366212b..e235ef2f67 100644
--- a/core/src/test/scala/spark/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/spark/SizeEstimatorSuite.scala
@@ -3,7 +3,6 @@ package spark
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfterAll
import org.scalatest.PrivateMethodTester
-import org.scalatest.matchers.ShouldMatchers
class DummyClass1 {}
@@ -20,8 +19,17 @@ class DummyClass4(val d: DummyClass3) {
val x: Int = 0
}
+object DummyString {
+ def apply(str: String) : DummyString = new DummyString(str.toArray)
+}
+class DummyString(val arr: Array[Char]) {
+ override val hashCode: Int = 0
+ // JDK-7 has an extra hash32 field http://hg.openjdk.java.net/jdk7u/jdk7u6/jdk/rev/11987e85555f
+ @transient val hash32: Int = 0
+}
+
class SizeEstimatorSuite
- extends FunSuite with BeforeAndAfterAll with PrivateMethodTester with ShouldMatchers {
+ extends FunSuite with BeforeAndAfterAll with PrivateMethodTester {
var oldArch: String = _
var oldOops: String = _
@@ -45,15 +53,13 @@ class SizeEstimatorSuite
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- SizeEstimator.estimate("") should (equal (48) or equal (40))
- SizeEstimator.estimate("a") should (equal (56) or equal (48))
- SizeEstimator.estimate("ab") should (equal (56) or equal (48))
- SizeEstimator.estimate("abcdefgh") should (equal(64) or equal(56))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
@@ -105,18 +111,16 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- expect(40)(SizeEstimator.estimate(""))
- expect(48)(SizeEstimator.estimate("a"))
- expect(48)(SizeEstimator.estimate("ab"))
- expect(56)(SizeEstimator.estimate("abcdefgh"))
+ expect(40)(SizeEstimator.estimate(DummyString("")))
+ expect(48)(SizeEstimator.estimate(DummyString("a")))
+ expect(48)(SizeEstimator.estimate(DummyString("ab")))
+ expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
- // NOTE: The String class definition changed in JDK 7 to exclude the int fields count and length.
- // This means that the size of strings will be lesser by 8 bytes in JDK 7 compared to JDK 6.
- // http://mail.openjdk.java.net/pipermail/core-libs-dev/2012-May/010257.html
- // Work around to check for either.
+ // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
+ // (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("64-bit arch with no compressed oops") {
val arch = System.setProperty("os.arch", "amd64")
val oops = System.setProperty("spark.test.useCompressedOops", "false")
@@ -124,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- SizeEstimator.estimate("") should (equal (64) or equal (56))
- SizeEstimator.estimate("a") should (equal (72) or equal (64))
- SizeEstimator.estimate("ab") should (equal (72) or equal (64))
- SizeEstimator.estimate("abcdefgh") should (equal (80) or equal (72))
+ expect(56)(SizeEstimator.estimate(DummyString("")))
+ expect(64)(SizeEstimator.estimate(DummyString("a")))
+ expect(64)(SizeEstimator.estimate(DummyString("ab")))
+ expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
new file mode 100644
index 0000000000..ba6f8b588f
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala
@@ -0,0 +1,42 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import spark.TaskContext
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+class TaskContextSuite extends FunSuite with BeforeAndAfter {
+
+ var sc: SparkContext = _
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ test("Calls executeOnCompleteCallbacks after failure") {
+ var completed = false
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getSplits = Array[Split](StubSplit(0))
+ override def compute(split: Split, context: TaskContext) = {
+ context.addOnCompleteCallback(() => completed = true)
+ sys.error("failed")
+ }
+ }
+ val func = (c: TaskContext, i: Iterator[String]) => i.next
+ val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ intercept[RuntimeException] {
+ task.run(0)
+ }
+ assert(completed === true)
+ }
+
+ case class StubSplit(val index: Int) extends Split
+} \ No newline at end of file
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index ad2253596d..a1aeb12f25 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -7,6 +7,10 @@ import akka.actor._
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.PrivateMethodTester
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.time.SpanSugar._
import spark.KryoSerializer
import spark.SizeEstimator
@@ -20,15 +24,16 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var oldArch: String = null
var oldOops: String = null
var oldHeartBeat: String = null
-
- // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+
+ // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
+ System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer
before {
actorSystem = ActorSystem("test")
- master = new BlockManagerMaster(actorSystem, true, true)
+ master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
- // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
+ // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
oldOops = System.setProperty("spark.test.useCompressedOops", "true")
oldHeartBeat = System.setProperty("spark.storage.disableBlockManagerHeartBeat", "true")
@@ -63,7 +68,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
}
- test("manager-master interaction") {
+  // Checks that StorageLevel(...) returns the same instance for equal arguments,
+  // both when constructed directly and after a serialize/deserialize round trip.
+  test("StorageLevel object caching") {
+    val level1 = StorageLevel(false, false, false, 3)
+    val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1
+    val level3 = StorageLevel(false, false, false, 2) // this should return a different object
+    assert(level2 === level1, "level2 is not same as level1")
+    assert(level2.eq(level1), "level2 is not the same object as level1")
+    assert(level3 != level1, "level3 is same as level1")
+    // Round-trip both levels through spark.Utils serialization.
+    val bytes1 = spark.Utils.serialize(level1)
+    val level1_ = spark.Utils.deserialize[StorageLevel](bytes1)
+    val bytes2 = spark.Utils.serialize(level2)
+    val level2_ = spark.Utils.deserialize[StorageLevel](bytes2)
+    assert(level1_ === level1, "Deserialized level1 not same as original level1")
+    // Message names level1, matching the object actually compared.
+    assert(level1_.eq(level1), "Deserialized level1 not the same object as original level1")
+    assert(level2_ === level2, "Deserialized level2 not same as original level2")
+    assert(level2_.eq(level1), "Deserialized level2 not the same object as original level1")
+  }
+
+ test("BlockManagerId object caching") {
+ val id1 = BlockManagerId("XXX", 1)
+ val id2 = BlockManagerId("XXX", 1) // this should return the same object as id1
+ val id3 = BlockManagerId("XXX", 2) // this should return a different object
+ assert(id2 === id1, "id2 is not same as id1")
+ assert(id2.eq(id1), "id2 is not the same object as id1")
+ assert(id3 != id1, "id3 is same as id1")
+ val bytes1 = spark.Utils.serialize(id1)
+ val id1_ = spark.Utils.deserialize[BlockManagerId](bytes1)
+ val bytes2 = spark.Utils.serialize(id2)
+ val id2_ = spark.Utils.deserialize[BlockManagerId](bytes2)
+ assert(id1_ === id1, "Deserialized id1 is not same as original id1")
+ assert(id1_.eq(id1), "Deserialized id1 is not the same object as original id1")
+ assert(id2_ === id2, "Deserialized id2 is not same as original id2")
+ assert(id2_.eq(id1), "Deserialized id2 is not the same object as original id1")
+ }
+
+ test("master + 1 manager interaction") {
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -74,83 +113,122 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY)
store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY, false)
- // Checking whether blocks are in memory
+ // Checking whether blocks are in memory
assert(store.getSingle("a1") != None, "a1 was not in store")
assert(store.getSingle("a2") != None, "a2 was not in store")
assert(store.getSingle("a3") != None, "a3 was not in store")
// Checking whether master knows about the blocks or not
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
- assert(master.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2")
- assert(master.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3")
-
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3").size === 0, "master was told about a3")
+
// Drop a1 and a2 from memory; this should be reported back to the master
store.dropFromMemory("a1", null)
store.dropFromMemory("a2", null)
assert(store.getSingle("a1") === None, "a1 not removed from store")
assert(store.getSingle("a2") === None, "a2 not removed from store")
- assert(master.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1")
- assert(master.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2")
+ assert(master.getLocations("a1").size === 0, "master did not remove a1")
+ assert(master.getLocations("a2").size === 0, "master did not remove a2")
}
- test("reregistration on heart beat") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
+ test("master + 2 managers interaction") {
store = new BlockManager(actorSystem, master, serializer, 2000)
+ store2 = new BlockManager(actorSystem, master, new KryoSerializer, 2000)
+
+ val peers = master.getPeers(store.blockManagerId, 1)
+ assert(peers.size === 1, "master did not return the other manager as a peer")
+ assert(peers.head === store2.blockManagerId, "peer returned by master is not the other manager")
+
val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_2)
+ store2.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_2)
+ assert(master.getLocations("a1").size === 2, "master did not report 2 locations for a1")
+ assert(master.getLocations("a2").size === 2, "master did not report 2 locations for a2")
+ }
- store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+ test("removing block") {
+ store = new BlockManager(actorSystem, master, serializer, 2000)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
- assert(store.getSingle("a1") != None, "a1 was not in store")
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ // Putting a1, a2 and a3 in memory and telling master only about a1 and a2
+ store.putSingle("a1-to-remove", a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a2-to-remove", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle("a3-to-remove", a3, StorageLevel.MEMORY_ONLY, false)
- master.notifyADeadHost(store.blockManagerId.ip)
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
+ // Checking whether blocks are in memory and memory size
+ val memStatus = master.getMemoryStatus.head._2
+ assert(memStatus._1 == 2000L, "total memory " + memStatus._1 + " should equal 2000")
+ assert(memStatus._2 <= 1200L, "remaining memory " + memStatus._2 + " should <= 1200")
+ assert(store.getSingle("a1-to-remove") != None, "a1 was not in store")
+ assert(store.getSingle("a2-to-remove") != None, "a2 was not in store")
+ assert(store.getSingle("a3-to-remove") != None, "a3 was not in store")
- store invokePrivate heartBeat()
- assert(master.mustGetLocations(GetLocations("a1")).size > 0,
- "a1 was not reregistered with master")
+ // Checking whether master knows about the blocks or not
+ assert(master.getLocations("a1-to-remove").size > 0, "master was not told about a1")
+ assert(master.getLocations("a2-to-remove").size > 0, "master was not told about a2")
+ assert(master.getLocations("a3-to-remove").size === 0, "master was told about a3")
+
+ // Remove a1 and a2 and a3. Should be no-op for a3.
+ master.removeBlock("a1-to-remove")
+ master.removeBlock("a2-to-remove")
+ master.removeBlock("a3-to-remove")
+
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a1-to-remove") should be (None)
+ master.getLocations("a1-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a2-to-remove") should be (None)
+ master.getLocations("a2-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ store.getSingle("a3-to-remove") should not be (None)
+ master.getLocations("a3-to-remove") should have size 0
+ }
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ val memStatus = master.getMemoryStatus.head._2
+ memStatus._1 should equal (2000L)
+ memStatus._2 should equal (2000L)
+ }
}
- test("reregistration on block update") {
+ test("reregistration on heart beat") {
+ val heartBeat = PrivateMethod[Unit]('heartBeat)
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
- val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ assert(store.getSingle("a1") != None, "a1 was not in store")
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
master.notifyADeadHost(store.blockManagerId.ip)
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
-
- store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- assert(master.mustGetLocations(GetLocations("a1")).size > 0,
- "a1 was not reregistered with master")
- assert(master.mustGetLocations(GetLocations("a2")).size > 0,
- "master was not told about a2")
+ store invokePrivate heartBeat()
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
}
- test("deregistration on duplicate") {
- val heartBeat = PrivateMethod[Unit]('heartBeat)
+ test("reregistration on block update") {
store = new BlockManager(actorSystem, master, serializer, 2000)
val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ assert(master.getLocations("a1").size > 0, "master was not told about a1")
- store2 = new BlockManager(actorSystem, master, serializer, 2000)
-
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a1 was not removed from master")
+ master.notifyADeadHost(store.blockManagerId.ip)
+ assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
- store invokePrivate heartBeat()
-
- assert(master.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1")
+ store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
- store2 invokePrivate heartBeat()
-
- assert(master.mustGetLocations(GetLocations("a1")).size == 0, "a2 was not removed from master")
+ assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
+ assert(master.getLocations("a2").size > 0, "master was not told about a2")
}
test("in-memory LRU storage") {
@@ -171,7 +249,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(store.getSingle("a2") != None, "a2 was not in store")
assert(store.getSingle("a3") === None, "a3 was in store")
}
-
+
test("in-memory LRU storage with serialization") {
store = new BlockManager(actorSystem, master, serializer, 1200)
val a1 = new Array[Byte](400)
diff --git a/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
new file mode 100644
index 0000000000..794063fb6d
--- /dev/null
+++ b/core/src/test/scala/spark/util/RateLimitedOutputStreamSuite.scala
@@ -0,0 +1,23 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import java.io.ByteArrayOutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStreamSuite extends FunSuite {
+
+ private def benchmark[U](f: => U): Long = {
+ val start = System.nanoTime
+ f
+ System.nanoTime - start
+ }
+
+ test("write") {
+ val underlying = new ByteArrayOutputStream
+ val data = "X" * 41000
+ val stream = new RateLimitedOutputStream(underlying, 10000)
+ val elapsedNs = benchmark { stream.write(data.getBytes("UTF-8")) }
+ assert(SECONDS.convert(elapsedNs, NANOSECONDS) == 4)
+ assert(underlying.toString("UTF-8") == data)
+ }
+}
diff --git a/docs/README.md b/docs/README.md
index 092153070e..887f407f18 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -25,10 +25,12 @@ To mark a block of code in your markdown to be syntax highlighted by jekyll duri
// supported languages too.
{% endhighlight %}
-## Scaladoc
+## API Docs (Scaladoc and Epydoc)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
-When you run `jekyll` in the docs directory, it will also copy over the scala doc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc.
+Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the SPARK_PROJECT_ROOT/python directory.
-NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`.
+When you run `jekyll` in the docs directory, it will also copy over the scaladoc for the various Spark subprojects into the docs directory (and then also into the _site directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the PySpark docs using [epydoc](http://epydoc.sourceforge.net/).
+
+NOTE: To skip the step of building and copying over the scaladoc when you build the docs, run `SKIP_SCALADOC=1 jekyll`. Similarly, `SKIP_EPYDOC=1 jekyll` will skip PySpark API doc generation.
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 7244ab6fc9..94baa634aa 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -47,10 +47,19 @@
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="scala-programming-guide.html">Scala</a></li>
<li><a href="java-programming-guide.html">Java</a></li>
+ <li><a href="python-programming-guide.html">Python</a></li>
+ <li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
</ul>
</li>
- <li><a href="api/core/index.html">API (Scaladoc)</a></li>
+ <li class="dropdown">
+ <a href="#" class="dropdown-toggle" data-toggle="dropdown">API (Scaladoc)<b class="caret"></b></a>
+ <ul class="dropdown-menu">
+ <li><a href="api/core/index.html">Spark Scala/Java (Scaladoc)</a></li>
+ <li><a href="api/pyspark/index.html">Spark Python (Epydoc)</a></li>
+ <li><a href="api/streaming/index.html">Spark Streaming Scala/Java (Scaladoc) </a></li>
+ </ul>
+ </li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index e61c105449..e400dec619 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -2,7 +2,7 @@ require 'fileutils'
include FileUtils
if ENV['SKIP_SCALADOC'] != '1'
- projects = ["core", "examples", "repl", "bagel"]
+ projects = ["core", "examples", "repl", "bagel", "streaming"]
puts "Moving to project root and building scaladoc."
curr_dir = pwd
@@ -11,7 +11,7 @@ if ENV['SKIP_SCALADOC'] != '1'
puts "Running sbt/sbt doc from " + pwd + "; this may take a few minutes..."
puts `sbt/sbt doc`
- puts "moving back into docs dir."
+ puts "Moving back into docs dir."
cd("docs")
# Copy over the scaladoc from each project into the docs directory.
@@ -28,3 +28,20 @@ if ENV['SKIP_SCALADOC'] != '1'
cp_r(source + "/.", dest)
end
end
+
+if ENV['SKIP_EPYDOC'] != '1'
+ puts "Moving to python directory and building epydoc."
+ cd("../python")
+ puts `epydoc --config epydoc.conf`
+
+ puts "Moving back into docs dir."
+ cd("../docs")
+
+ puts "echo making directory pyspark"
+ mkdir_p "pyspark"
+
+ puts "cp -r ../python/docs/. api/pyspark"
+ cp_r("../python/docs/.", "api/pyspark")
+
+ cd("..")
+end
diff --git a/docs/api.md b/docs/api.md
index 43548b223c..e86d07770a 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -5,6 +5,8 @@ title: Spark API documentation (Scaladoc)
Here you can find links to the Scaladoc generated for the Spark sbt subprojects. If the following links don't work, try running `sbt/sbt doc` from the Spark project home directory.
-- [Core](api/core/index.html)
-- [Examples](api/examples/index.html)
+- [Spark](api/core/index.html)
+- [Spark Examples](api/examples/index.html)
+- [Spark Streaming](api/streaming/index.html)
- [Bagel](api/bagel/index.html)
+- [PySpark](api/pyspark/index.html)
diff --git a/docs/configuration.md b/docs/configuration.md
index d8317ea97c..036a0df480 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -198,6 +198,15 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.akka.frameSize</td>
+ <td>10</td>
+ <td>
+ Maximum message size to allow in "control plane" communication (for serialized tasks and task
+ results), in MB. Increase this if your tasks need to send back large results to the master
+ (e.g. using <code>collect()</code> on a large dataset).
+ </td>
+</tr>
+<tr>
<td>spark.akka.threads</td>
<td>4</td>
<td>
@@ -206,6 +215,13 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+ <td>spark.akka.timeout</td>
+ <td>20</td>
+ <td>
+ Communication timeout between Spark nodes.
+ </td>
+</tr>
+<tr>
<td>spark.master.host</td>
<td>(local hostname)</td>
<td>
@@ -219,6 +235,17 @@ Apart from these, the following properties are also available, and may be useful
Port for the master to listen on.
</td>
</tr>
+<tr>
+ <td>spark.cleaner.delay</td>
+ <td>(disable)</td>
+ <td>
+ Duration (minutes) of how long Spark will remember any metadata (stages generated, tasks generated, etc.).
+ Periodic cleanups will ensure that metadata older than this duration will be forgotten. This is
+ useful for running Spark for many hours / days (for example, running 24/7 in case of Spark Streaming
+ applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
+ </td>
+</tr>
+
</table>
# Configuring Logging
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 6e1f7fd3b1..931b7a66bd 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -96,7 +96,9 @@ permissions on your private key file, you can run `launch` with the
`spark-ec2` to attach a persistent EBS volume to each node for
storing the persistent HDFS.
- Finally, if you get errors while running your jobs, look at the slave's logs
- for that job using the Mesos web UI (`http://<master-hostname>:8080`).
+ for that job inside of the Mesos work directory (/mnt/mesos-work). You can
+ also view the status of the cluster using the Mesos web UI
+ (`http://<master-hostname>:8080`).
# Configuration
diff --git a/docs/index.md b/docs/index.md
index ed9953a590..c6ef507cb0 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -7,11 +7,11 @@ title: Spark Overview
TODO(andyk): Rewrite to make the Java API a first class part of the story.
{% endcomment %}
-Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an
-interpreter. It provides clean, language-integrated APIs in Scala and Java, with a rich array of parallel operators. Spark can
-run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager,
+Spark is a MapReduce-like cluster computing framework designed for low-latency iterative jobs and interactive use from an interpreter.
+It provides clean, language-integrated APIs in [Scala](scala-programming-guide.html), [Java](java-programming-guide.html), and [Python](python-programming-guide.html), with a rich array of parallel operators.
+Spark can run on top of the [Apache Mesos](http://incubator.apache.org/mesos/) cluster manager,
[Hadoop YARN](http://hadoop.apache.org/docs/r2.0.1-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html),
-Amazon EC2, or without an independent resource manager ("standalone mode").
+Amazon EC2, or without an independent resource manager ("standalone mode").
# Downloading
@@ -58,7 +58,15 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Quick Start](quick-start.html): a quick introduction to the Spark API; start here!
* [Spark Programming Guide](scala-programming-guide.html): an overview of Spark concepts, and details on the Scala API
+* [Streaming Programming Guide](streaming-programming-guide.html): an API preview of Spark Streaming
* [Java Programming Guide](java-programming-guide.html): using Spark from Java
+* [Python Programming Guide](python-programming-guide.html): using Spark from Python
+
+**API Docs:**
+
+* [Spark Java/Scala (Scaladoc)](api/core/index.html)
+* [Spark Python (Epydoc)](api/pyspark/index.html)
+* [Spark Streaming Java/Scala (Scaladoc)](api/streaming/index.html)
**Deployment guides:**
@@ -72,7 +80,6 @@ of `project/SparkBuild.scala`, then rebuilding Spark (`sbt/sbt clean compile`).
* [Configuration](configuration.html): customize Spark via its configuration system
* [Tuning Guide](tuning.html): best practices to optimize performance and memory use
-* [API Docs (Scaladoc)](api/core/index.html)
* [Bagel](bagel-programming-guide.html): an implementation of Google's Pregel on Spark
* [Contributing to Spark](contributing-to-spark.html)
diff --git a/docs/java-programming-guide.md b/docs/java-programming-guide.md
index 188ca4995e..37a906ea1c 100644
--- a/docs/java-programming-guide.md
+++ b/docs/java-programming-guide.md
@@ -75,7 +75,8 @@ class has a single abstract method, `call()`, that must be implemented.
## Storage Levels
RDD [storage level](scala-programming-guide.html#rdd-persistence) constants, such as `MEMORY_AND_DISK`, are
-declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class.
+declared in the [spark.api.java.StorageLevels](api/core/index.html#spark.api.java.StorageLevels) class. To
+define your own storage level, you can use StorageLevels.create(...).
# Other Features
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
new file mode 100644
index 0000000000..a840b9b34b
--- /dev/null
+++ b/docs/python-programming-guide.md
@@ -0,0 +1,110 @@
+---
+layout: global
+title: Python Programming Guide
+---
+
+
+The Spark Python API (PySpark) exposes most of the Spark features available in the Scala version to Python.
+To learn the basics of Spark, we recommend reading through the
+[Scala programming guide](scala-programming-guide.html) first; it should be
+easy to follow even if you don't know Scala.
+This guide will show how to use the Spark features described there in Python.
+
+# Key Differences in the Python API
+
+There are a few key differences between the Python and Scala APIs:
+
+* Python is dynamically typed, so RDDs can hold objects of different types.
+* PySpark does not currently support the following Spark features:
+ - Special functions on RDDs of doubles, such as `mean` and `stdev`
+ - `lookup`
+ - `persist` at storage levels other than `MEMORY_ONLY`
+ - `sample`
+ - `sort`
+
+In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types.
+Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax:
+
+{% highlight python %}
+logData = sc.textFile(logFile).cache()
+errors = logData.filter(lambda s: 'ERROR' in s.split())
+{% endhighlight %}
+
+You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
+
+{% highlight python %}
+def is_error(line):
+ return 'ERROR' in line.split()
+errors = logData.filter(is_error)
+{% endhighlight %}
+
+Functions can access objects in enclosing scopes, although modifications to those objects within RDD methods will not be propagated to other tasks:
+
+{% highlight python %}
+error_keywords = ["Exception", "Error"]
+def is_error(line):
+ words = line.split()
+ return any(keyword in words for keyword in error_keywords)
+errors = logData.filter(is_error)
+{% endhighlight %}
+
+PySpark will automatically ship these functions to workers, along with any objects that they reference.
+Instances of classes will be serialized and shipped to workers by PySpark, but classes themselves cannot be automatically distributed to workers.
+The [Standalone Use](#standalone-use) section describes how to ship code dependencies to workers.
+
+# Installing and Configuring PySpark
+
+PySpark requires Python 2.6 or higher.
+PySpark jobs are executed using a standard CPython interpreter in order to support Python modules that use C extensions.
+We have not tested PySpark with Python 3 or with alternative Python interpreters, such as [PyPy](http://pypy.org/) or [Jython](http://www.jython.org/).
+By default, PySpark's scripts will run programs using `python`; an alternate Python executable may be specified by setting the `PYSPARK_PYTHON` environment variable in `conf/spark-env.sh`.
+
+All of PySpark's library dependencies, including [Py4J](http://py4j.sourceforge.net/), are bundled with PySpark and automatically imported.
+
+Standalone PySpark jobs should be run using the `pyspark` script, which automatically configures the Java and Python environment using the settings in `conf/spark-env.sh`.
+The script automatically adds the `pyspark` package to the `PYTHONPATH`.
+
+
+# Interactive Use
+
+The `pyspark` script launches a Python interpreter that is configured to run PySpark jobs.
+When run without any input files, `pyspark` launches a shell that can be used to explore data interactively, which is a simple way to learn the API:
+
+{% highlight python %}
+>>> words = sc.textFile("/usr/share/dict/words")
+>>> words.filter(lambda w: w.startswith("spar")).take(5)
+[u'spar', u'sparable', u'sparada', u'sparadrap', u'sparagrass']
+{% endhighlight %}
+
+By default, the `pyspark` shell creates a SparkContext that runs jobs locally.
+To connect to a non-local cluster, set the `MASTER` environment variable.
+For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html):
+
+{% highlight shell %}
+$ MASTER=spark://IP:PORT ./pyspark
+{% endhighlight %}
+
+
+# Standalone Use
+
+PySpark can also be used from standalone Python scripts by creating a SparkContext in your script and running the script using `pyspark`.
+The Quick Start guide includes a [complete example](quick-start.html#a-standalone-job-in-python) of a standalone Python job.
+
+Code dependencies can be deployed by listing them in the `pyFiles` option in the SparkContext constructor:
+
+{% highlight python %}
+from pyspark import SparkContext
+sc = SparkContext("local", "Job Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg'])
+{% endhighlight %}
+
+Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
+Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
+
+# Where to Go from Here
+
+PySpark includes several sample programs using the Python API in `python/examples`.
+You can run them by passing the files to the `pyspark` script -- for example `./pyspark python/examples/wordcount.py`.
+Each example program prints usage help when run without any arguments.
+
+We currently provide [API documentation](api/pyspark/index.html) for the Python API as Epydoc.
+Many of the RDD method descriptions contain [doctests](http://docs.python.org/2/library/doctest.html) that provide additional usage examples.
diff --git a/docs/quick-start.md b/docs/quick-start.md
index 177cb14551..a4c4c9a8fb 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -6,7 +6,8 @@ title: Quick Start
* This will become a table of contents (this text will be scraped).
{:toc}
-This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala and Java. See the [programming guide](scala-programming-guide.html) for a more complete reference.
+This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive Scala shell (don't worry if you don't know Scala -- you will not need much for this), then show how to write standalone jobs in Scala, Java, and Python.
+See the [programming guide](scala-programming-guide.html) for a more complete reference.
To follow along with this guide, you only need to have successfully built Spark on one machine. Simply go into your Spark directory and run:
@@ -200,6 +201,16 @@ To build the job, we also write a Maven `pom.xml` file that lists Spark as a dep
<name>Simple Project</name>
<packaging>jar</packaging>
<version>1.0</version>
+ <repositories>
+ <repository>
+ <id>Spray.cc repository</id>
+ <url>http://repo.spray.cc</url>
+ </repository>
+ <repository>
+ <id>Typesafe repository</id>
+ <url>http://repo.typesafe.com/typesafe/releases</url>
+ </repository>
+ </repositories>
<dependencies>
<dependency> <!-- Spark dependency -->
<groupId>org.spark-project</groupId>
@@ -230,3 +241,40 @@ Lines with a: 8422, Lines with b: 1836
{% endhighlight %}
This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
+
+# A Standalone Job In Python
+Now we will show how to write a standalone job using the Python API (PySpark).
+
+As an example, we'll create a simple Spark job, `SimpleJob.py`:
+
+{% highlight python %}
+"""SimpleJob.py"""
+from pyspark import SparkContext
+
+logFile = "/var/log/syslog" # Should be some file on your system
+sc = SparkContext("local", "Simple job")
+logData = sc.textFile(logFile).cache()
+
+numAs = logData.filter(lambda s: 'a' in s).count()
+numBs = logData.filter(lambda s: 'b' in s).count()
+
+print "Lines with a: %i, lines with b: %i" % (numAs, numBs)
+{% endhighlight %}
+
+
+This job simply counts the number of lines containing 'a' and the number containing 'b' in a system log file.
+Like in the Scala and Java examples, we use a SparkContext to create RDDs.
+We can pass Python functions to Spark, which are automatically serialized along with any variables that they reference.
+For jobs that use custom classes or third-party libraries, we can add those code dependencies to SparkContext to ensure that they will be available on remote machines; this is described in more detail in the [Python programming guide](python-programming-guide.html).
+`SimpleJob` is simple enough that we do not need to specify any code dependencies.
+
+We can run this job using the `pyspark` script:
+
+{% highlight bash %}
+$ cd $SPARK_HOME
+$ ./pyspark SimpleJob.py
+...
+Lines with a: 8422, lines with b: 1836
+{% endhighlight %}
+
+This example only runs the job locally; for a tutorial on running jobs across several machines, see the [Standalone Mode](spark-standalone.html) documentation, and consider using a distributed input source, such as HDFS.
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 7350eca837..301b330a79 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -301,7 +301,8 @@ We recommend going through the following process to select one:
* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web
application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones
let you continue running tasks on the RDD without waiting to recompute a lost partition.
-
+
+If you want to define your own storage level (say, with a replication factor of 3 instead of 2), then use the factory method `apply()` of the [`StorageLevel`](api/core/index.html#spark.storage.StorageLevel$) singleton object.
# Shared Variables
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index e0ba7c35cb..bf296221b8 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor
</tr>
<tr>
<td><code>-c CORES</code>, <code>--cores CORES</code></td>
- <td>Number of CPU cores to use (default: all available); only on worker</td>
+ <td>Total CPU cores to allow Spark jobs to use on the machine (default: all available); only on worker</td>
</tr>
<tr>
<td><code>-m MEM</code>, <code>--memory MEM</code></td>
- <td>Amount of memory to use, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
+ <td>Total amount of memory to allow Spark jobs to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
</tr>
<tr>
<td><code>-d DIR</code>, <code>--work-dir DIR</code></td>
@@ -66,9 +66,20 @@ Finally, the following configuration options can be passed to the master and wor
# Cluster Launch Scripts
-To launch a Spark standalone cluster with the deploy scripts, you need to set up two files, `conf/spark-env.sh` and `conf/slaves`. The `conf/spark-env.sh` file lets you specify global settings for the master and slave instances, such as memory, or port numbers to bind to, while `conf/slaves` is a list of slave nodes. The system requires that all the slave machines have the same configuration files, so *copy these files to each machine*.
+To launch a Spark standalone cluster with the deploy scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file.
-In `conf/spark-env.sh`, you can set the following parameters, in addition to the [standard Spark configuration settings](configuration.html):
+Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
+
+- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on.
+- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
+- `bin/start-all.sh` - Starts both a master and a number of slaves as described above.
+- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script.
+- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`.
+- `bin/stop-all.sh` - Stops both the master and the slaves as described above.
+
+Note that these scripts must be executed on the machine you want to run the Spark master on, not your local machine.
+
+You can optionally configure the cluster further by setting environment variables in `conf/spark-env.sh`. Create this file by starting with the `conf/spark-env.sh.template`, and _copy it to all your worker machines_ for the settings to take effect. The following settings are available:
<table class="table">
<tr><th style="width:21%">Environment Variable</th><th>Meaning</th></tr>
@@ -89,35 +100,23 @@ In `conf/spark-env.sh`, you can set the following parameters, in addition to the
<td>Start the Spark worker on a specific port (default: random)</td>
</tr>
<tr>
+ <td><code>SPARK_WORKER_DIR</code></td>
+ <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
+ </tr>
+ <tr>
<td><code>SPARK_WORKER_CORES</code></td>
- <td>Number of cores to use (default: all available cores)</td>
+ <td>Total number of cores to allow Spark jobs to use on the machine (default: all available cores)</td>
</tr>
<tr>
<td><code>SPARK_WORKER_MEMORY</code></td>
- <td>How much memory to use, e.g. 1000M, 2G (default: total memory minus 1 GB)</td>
+ <td>Total amount of memory to allow Spark jobs to use on the machine, e.g. 1000M, 2G (default: total memory minus 1 GB); note that each job's <i>individual</i> memory is configured using <code>SPARK_MEM</code></td>
</tr>
<tr>
<td><code>SPARK_WORKER_WEBUI_PORT</code></td>
<td>Port for the worker web UI (default: 8081)</td>
</tr>
- <tr>
- <td><code>SPARK_WORKER_DIR</code></td>
- <td>Directory to run jobs in, which will include both logs and scratch space (default: SPARK_HOME/work)</td>
- </tr>
</table>
-In `conf/slaves`, include a list of all machines where you would like to start a Spark worker, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing purposes, you can have a single `localhost` entry in the slaves file.
-
-Once you've set up these configuration files, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`:
-
-- `bin/start-master.sh` - Starts a master instance on the machine the script is executed on.
-- `bin/start-slaves.sh` - Starts a slave instance on each machine specified in the `conf/slaves` file.
-- `bin/start-all.sh` - Starts both a master and a number of slaves as described above.
-- `bin/stop-master.sh` - Stops the master that was started via the `bin/start-master.sh` script.
-- `bin/stop-slaves.sh` - Stops the slave instances that were started via `bin/start-slaves.sh`.
-- `bin/stop-all.sh` - Stops both the master and the slaves as described above.
-
-Note that the scripts must be executed on the machine you want to run the Spark master on, not your local machine.
# Connecting a Job to the Cluster
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
new file mode 100644
index 0000000000..b6da7af654
--- /dev/null
+++ b/docs/streaming-programming-guide.md
@@ -0,0 +1,313 @@
+---
+layout: global
+title: Spark Streaming Programming Guide
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+A Spark Streaming application is very similar to a Spark application; it consists of a *driver program* that runs the user's `main` function and continuously executes various *parallel operations* on input streams of data. The main abstraction Spark Streaming provides is a *discretized stream* (DStream), which is a continuous sequence of RDDs (distributed collections of elements) representing a continuous stream of data. DStreams can be created from live incoming data (such as data from a socket, Kafka, etc.) or generated by transforming existing DStreams using parallel operators like map, reduce, and window. The basic processing model is as follows:
+(i) While a Spark Streaming driver program is running, the system receives data from various sources and divides the data into batches. Each batch of data is treated as an RDD, that is, an immutable and parallel collection of data. These input data RDDs are automatically persisted in memory (serialized by default) and replicated to two nodes for fault-tolerance. This sequence of RDDs is collectively referred to as an InputDStream.
+(ii) Data received by InputDStreams are processed using DStream operations. Since all data is represented as RDDs and all DStream operations as RDD operations, data is automatically recovered in the event of node failures.
+
+This guide shows how to start programming with DStreams.
+
+# Initializing Spark Streaming
+The first thing a Spark Streaming program must do is create a `StreamingContext` object, which tells Spark how to access a cluster. A `StreamingContext` can be created by using
+
+{% highlight scala %}
+new StreamingContext(master, jobName, batchDuration)
+{% endhighlight %}
+
+The `master` parameter is either the [Mesos master URL](running-on-mesos.html) (for running on a cluster) or the special "local" string (for local mode) that is used to create a Spark Context. For more information about this please refer to the [Spark programming guide](scala-programming-guide.html). The `jobName` is the name of the streaming job, which is the same as the jobName used in SparkContext. It is used to identify this job in the Mesos web UI. The `batchDuration` is the size of the batches (as explained earlier). This must be set carefully such that the cluster can keep up with the processing of the data streams. Starting with something conservative like 5 seconds may be a good start. See the [Performance Tuning](#setting-the-right-batch-size) section for a detailed discussion.
+
+This constructor creates a SparkContext object using the given `master` and `jobName` parameters. However, if you already have a SparkContext or you need to create a custom SparkContext by specifying a list of JARs, then a StreamingContext can be created from the existing SparkContext, by using
+{% highlight scala %}
+new StreamingContext(sparkContext, batchDuration)
+{% endhighlight %}
+
+
+
+# Attaching Input Sources - InputDStreams
+The StreamingContext is used to create InputDStreams from input sources:
+
+{% highlight scala %}
+// Assuming ssc is the StreamingContext
+ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port
+ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory
+{% endhighlight %}
+
+A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next.
+
+
+
+# DStream Operations
+Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source.
+
+## Transformations
+
+DStreams support many of the transformations available on normal Spark RDDs:
+
+<table class="table">
+<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr>
+<tr>
+ <td> <b>map</b>(<i>func</i>) </td>
+ <td> Returns a new DStream formed by passing each element of the source through a function <i>func</i>. </td>
+</tr>
+<tr>
+ <td> <b>filter</b>(<i>func</i>) </td>
+ <td> Returns a new stream formed by selecting those elements of the source on which <i>func</i> returns true. </td>
+</tr>
+<tr>
+ <td> <b>flatMap</b>(<i>func</i>) </td>
+ <td> Similar to map, but each input item can be mapped to 0 or more output items (so <i>func</i> should return a Seq rather than a single item). </td>
+</tr>
+<tr>
+ <td> <b>mapPartitions</b>(<i>func</i>) </td>
+ <td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of type
+ Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
+</tr>
+<tr>
+ <td> <b>union</b>(<i>otherStream</i>) </td>
+ <td> Return a new stream that contains the union of the elements in the source stream and the argument. </td>
+</tr>
+<tr>
+ <td> <b>groupByKey</b>([<i>numTasks</i>]) </td>
+ <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs. <br />
+<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.
+</td>
+</tr>
+<tr>
+ <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
+ <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
+<tr>
+ <td> <b>join</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td>
+ <td> When called on streams of type (K, V) and (K, W), returns a stream of (K, (V, W)) pairs with all pairs of elements for each key. </td>
+</tr>
+<tr>
+ <td> <b>cogroup</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td>
+ <td> When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.</td>
+</tr>
+<tr>
+ <td> <b>reduce</b>(<i>func</i>) </td>
+ <td> Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td>
+</tr>
+<tr>
+ <td> <b>transform</b>(<i>func</i>) </td>
+ <td> Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream. </td>
+</tr>
+</table>
+
+Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a <i>windowDuration</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated.
+
+<table class="table">
+<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr>
+<tr>
+ <td> <b>window</b>(<i>windowDuration</i>, <i>slideTime</i>) </td>
+ <td> Return a new stream which is computed based on windowed batches of the source stream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
+ </td>
+</tr>
+<tr>
+ <td> <b>countByWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>) </td>
+ <td> Return a sliding count of elements in the stream. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+ </td>
+</tr>
+<tr>
+ <td> <b>reduceByWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>) </td>
+ <td> Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using <i>func</i>. The function should be associative so that it can be computed correctly in parallel. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+ </td>
+</tr>
+<tr>
+ <td> <b>groupByKeyAndWindow</b>(windowDuration, slideDuration, [<i>numTasks</i>])
+ </td>
+ <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window. <br />
+<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+</td>
+</tr>
+<tr>
+ <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, [<i>numTasks</i>]) </td>
+ <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
+ <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+</td>
+</tr>
+<tr>
+ <td> <b>countByKeyAndWindow</b>([<i>numTasks</i>]) </td>
+ <td> When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument.
+ <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
+</td>
+</tr>
+
+</table>
+
+A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#spark.streaming.PairDStreamFunctions).
+
+## Output Operations
+When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined:
+
+<table class="table">
+<tr><th style="width:25%">Operator</th><th>Meaning</th></tr>
+<tr>
+ <td> <b>foreach</b>(<i>func</i>) </td>
+ <td> The fundamental output operator. Applies a function, <i>func</i>, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system. </td>
+</tr>
+
+<tr>
+ <td> <b>print</b>() </td>
+ <td> Prints first ten elements of every batch of data in a DStream on the driver. </td>
+</tr>
+
+<tr>
+ <td> <b>saveAsObjectFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
+ <td> Save this DStream's contents as a <code>SequenceFile</code> of serialized objects. The file name at each batch interval is generated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>.
+ </td>
+</tr>
+
+<tr>
+ <td> <b>saveAsTextFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
+ <td> Save this DStream's contents as text files. The file name at each batch interval is generated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>. </td>
+</tr>
+
+<tr>
+ <td> <b>saveAsHadoopFiles</b>(<i>prefix</i>, [<i>suffix</i>]) </td>
+ <td> Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on <i>prefix</i> and <i>suffix</i>: <i>"prefix-TIME_IN_MS[.suffix]"</i>. </td>
+</tr>
+
+</table>
+
+## DStream Persistence
+Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`.
+
+Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence).
+
+# Starting the Streaming computation
+All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using
+{% highlight scala %}
+ssc.start()
+{% endhighlight %}
+
+Conversely, the computation can be stopped by using
+{% highlight scala %}
+ssc.stop()
+{% endhighlight %}
+
+# Example - NetworkWordCount.scala
+A good example to start off with is `spark.streaming.examples.NetworkWordCount`. This example counts the words received from a network server every second. Given below are the relevant sections of the source code. You can find the full source code in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/NetworkWordCount.scala`.
+
+{% highlight scala %}
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+...
+
+// Create the context and set up a network input stream to receive from a host:port
+val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1))
+val lines = ssc.networkTextStream(args(1), args(2).toInt)
+
+// Split the lines into words, count them, and print some of the counts on the master
+val words = lines.flatMap(_.split(" "))
+val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+wordCounts.print()
+
+// Start the computation
+ssc.start()
+{% endhighlight %}
+
+To run this example on your local machine, you need to first run a Netcat server by using
+
+{% highlight bash %}
+$ nc -lk 9999
+{% endhighlight %}
+
+Then, in a different terminal, you can start NetworkWordCount by using
+
+{% highlight bash %}
+$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999
+{% endhighlight %}
+
+This will make NetworkWordCount connect to the netcat server. Any lines typed in the terminal running the netcat server will be counted and printed on screen.
+
+<table>
+<td>
+{% highlight bash %}
+# TERMINAL 1
+# RUNNING NETCAT
+
+$ nc -lk 9999
+hello world
+
+
+
+
+
+...
+{% endhighlight %}
+</td>
+<td>
+{% highlight bash %}
+# TERMINAL 2: RUNNING NetworkWordCount
+...
+2012-12-31 18:47:10,446 INFO SparkContext: Job finished: run at ThreadPoolExecutor.java:886, took 0.038817 s
+-------------------------------------------
+Time: 1357008430000 ms
+-------------------------------------------
+(hello,1)
+(world,1)
+
+2012-12-31 18:47:10,447 INFO JobManager: Total delay: 0.44700 s for job 8 (execution: 0.44000 s)
+...
+{% endhighlight %}
+</td>
+</table>
+
+
+
+# Performance Tuning
+Getting the best performance of a Spark Streaming application on a cluster requires a bit of tuning. This section explains a number of the parameters and configurations that can be tuned to improve the performance of your application. At a high level, you need to consider two things:
+<ol>
+<li>Reducing the processing time of each batch of data by efficiently using cluster resources.</li>
+<li>Setting the right batch size such that the data processing can keep up with the data ingestion.</li>
+</ol>
+
+## Reducing the Processing Time of each Batch
+There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones.
+
+### Level of Parallelism
+Cluster resources may be underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default.
+
+### Data Serialization
+The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it.
+* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, Spark Streaming by default persists RDDs as serialized byte arrays to minimize pauses related to GC.
+* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to be deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck.
+
+### Task Launching Overheads
+If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves may be significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes:
+* Task Serialization: Using Kryo serialization for serializing tasks can reduce the task sizes, and therefore reduce the time taken to send them to the slaves.
+* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details.
+These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable.
+
+## Setting the Right Batch Size
+For a Spark Streaming application running on a cluster to be stable, the processing of the data streams must keep up with the rate of ingestion of the data streams. Depending on the type of computation, the batch size used may have a significant impact on the rate of ingestion that can be sustained by the Spark Streaming application on a fixed set of cluster resources. For example, let us consider the earlier NetworkWordCount example. For a particular data rate, the system may be able to keep up with reporting word counts every 2 seconds (i.e., batch size of 2 seconds), but not every 500 milliseconds.
+
+A good approach to figure out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with the data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained to be less than the batch size, then the system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and is therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that a momentary increase in the delay due to temporary data rate increases may be fine as long as the delay reduces back to a low value (i.e., less than batch size).
+
+## 24/7 Operation
+By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of its metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark to periodically clean up all metadata and persisted RDDs that are older than 10 minutes. Note that this property needs to be set before the SparkContext is created.
+
+This value is closely tied with any window operation that is being used. Any window operation would require the input data to be persisted in memory for at least the duration of the window. Hence it is necessary to set the delay to at least the value of the largest window operation used in the Spark Streaming application. If this delay is set too low, the application will throw an exception saying so.
+
+## Memory Tuning
+Tuning the memory usage and GC behavior of Spark applications has been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, we highlight a few customizations that are strongly recommended to minimize GC related pauses in Spark Streaming applications and achieve more consistent batch processing times.
+
+* <b>Default persistence level of DStreams</b>: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overhead, it significantly reduces GC pauses.
+
+* <b>Concurrent garbage collector</b>: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times.
+
+# Master Fault-tolerance (Alpha)
+TODO
+
+* Checkpointing of DStream graph
+
+* Recovery from master faults
+
+* Current state and future directions \ No newline at end of file
diff --git a/docs/tuning.md b/docs/tuning.md
index f18de8ff3a..9aaa53cd65 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -33,7 +33,7 @@ in your operations) and performance. It provides two serialization libraries:
Java serialization is flexible but often quite slow, and leads to large
serialized formats for many classes.
* [Kryo serialization](http://code.google.com/p/kryo/wiki/V1Documentation): Spark can also use
- the Kryo library (currently just version 1) to serialize objects more quickly. Kryo is significantly
+ the Kryo library (version 2) to serialize objects more quickly. Kryo is significantly
faster and more compact than Java serialization (often as much as 10x), but does not support all
`Serializable` types and requires you to *register* the classes you'll use in the program in advance
for best performance.
@@ -47,6 +47,8 @@ Finally, to register your classes with Kryo, create a public class that extends
`spark.kryo.registrator` system property to point to it, as follows:
{% highlight scala %}
+import com.esotericsoftware.kryo.Kryo
+
class MyRegistrator extends KryoRegistrator {
override def registerClasses(kryo: Kryo) {
kryo.register(classOf[MyClass1])
@@ -60,7 +62,7 @@ System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")
val sc = new SparkContext(...)
{% endhighlight %}
-The [Kryo documentation](http://code.google.com/p/kryo/wiki/V1Documentation) describes more advanced
+The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced
registration options, such as adding custom serialization code.
If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb`
@@ -147,7 +149,7 @@ the space allocated to the RDD cache to mitigate this.
**Measuring the Impact of GC**
-The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of
+The first step in GC tuning is to collect statistics on how frequently garbage collection occurs and the amount of
time spent GC. This can be done by adding `-verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps` to your
`SPARK_JAVA_OPTS` environment variable. Next time your Spark job is run, you will see messages printed in the worker's logs
each time a garbage collection occurs. Note these logs will be on your cluster's worker nodes (in the `stdout` files in
@@ -155,15 +157,15 @@ their work directories), *not* on your driver program.
**Cache Size Tuning**
-One important configuration parameter for GC is the amount of memory that should be used for
-caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
+One important configuration parameter for GC is the amount of memory that should be used for
+caching RDDs. By default, Spark uses 66% of the configured memory (`SPARK_MEM`) to cache RDDs. This means that
33% of memory is available for any objects created during task execution.
In case your tasks slow down and you find that your JVM is garbage-collecting frequently or running out of
-memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
-`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching,
-using a smaller cache should be sufficient to mitigate most of the garbage collection problems.
-In case you are interested in further tuning the Java GC, continue reading below.
+memory, lowering this value will help reduce the memory consumption. To change this to say 50%, you can call
+`System.setProperty("spark.storage.memoryFraction", "0.5")`. Combined with the use of serialized caching,
+using a smaller cache should be sufficient to mitigate most of the garbage collection problems.
+In case you are interested in further tuning the Java GC, continue reading below.
**Advanced GC Tuning**
@@ -172,9 +174,9 @@ To further tune garbage collection, we first need to understand some basic infor
* Java Heap space is divided in to two regions Young and Old. The Young generation is meant to hold short-lived objects
while the Old generation is intended for objects with longer lifetimes.
-* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
+* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
-* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
+* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old
enough or Survivor2 is full, it is moved to Old. Finally when Old is close to full, a full GC is invoked.
@@ -186,7 +188,7 @@ temporary objects created during task execution. Some steps which may be useful
before a task completes, it means that there isn't enough memory available for executing tasks.
* In the GC stats that are printed, if the OldGen is close to being full, reduce the amount of memory used for caching.
- This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow
+ This can be done using the `spark.storage.memoryFraction` property. It is better to cache fewer objects than to slow
down task execution!
* If there are too many minor collections but not many major GCs, allocating more memory for Eden would help. You
@@ -195,8 +197,8 @@ temporary objects created during task execution. Some steps which may be useful
up by 4/3 is to account for space used by survivor regions as well.)
* As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using
- the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the
- size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB,
+ the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the
+ size of the block. So if we wish to have 3 or 4 tasks worth of working space, and the HDFS block size is 64 MB,
we can estimate size of Eden to be `4*3*64MB`.
* Monitor how the frequency and time taken by garbage collection changes with the new settings.
diff --git a/examples/pom.xml b/examples/pom.xml
index 782c026d73..4d43103475 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -19,6 +19,11 @@
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.twitter4j</groupId>
+ <artifactId>twitter4j-stream</artifactId>
+ <version>3.0.3</version>
+ </dependency>
<dependency>
<groupId>org.scalatest</groupId>
@@ -45,6 +50,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -53,6 +63,12 @@
<classifier>hadoop1</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -72,6 +88,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
@@ -80,6 +102,12 @@
<classifier>hadoop2</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
diff --git a/examples/src/main/scala/spark/examples/LocalLR.scala b/examples/src/main/scala/spark/examples/LocalLR.scala
index f2ac2b3e06..9553162004 100644
--- a/examples/src/main/scala/spark/examples/LocalLR.scala
+++ b/examples/src/main/scala/spark/examples/LocalLR.scala
@@ -5,7 +5,7 @@ import spark.util.Vector
object LocalLR {
val N = 10000 // Number of data points
- val D = 10 // Numer of dimensions
+ val D = 10 // Number of dimensions
val R = 0.7 // Scaling factor
val ITERATIONS = 5
val rand = new Random(42)
diff --git a/examples/src/main/scala/spark/examples/SparkALS.scala b/examples/src/main/scala/spark/examples/SparkALS.scala
index fb28e2c932..5e01885dbb 100644
--- a/examples/src/main/scala/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/spark/examples/SparkALS.scala
@@ -7,6 +7,7 @@ import cern.jet.math._
import cern.colt.matrix._
import cern.colt.matrix.linalg._
import spark._
+import scala.Option
object SparkALS {
// Parameters set through command line arguments
@@ -42,7 +43,7 @@ object SparkALS {
return sqrt(sumSqs / (M * U))
}
- def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
+ def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
R: DoubleMatrix2D) : DoubleMatrix1D =
{
val U = us.size
@@ -68,50 +69,30 @@ object SparkALS {
return solved2D.viewColumn(0)
}
- def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val M = ms.size
- val F = ms(0).size
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
- // For each movie that the user rated
- for (i <- 0 until M) {
- val m = ms(i)
- // Add m * m^t to XtX
- blas.dger(1, m, m, XtX)
- // Add m * rating to Xty
- blas.daxpy(R.get(i, j), m, Xty)
- }
- // Add regularization coefs to diagonal terms
- for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
- }
- // Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- return solved2D.viewColumn(0)
- }
-
def main(args: Array[String]) {
var host = ""
var slices = 0
- args match {
- case Array(m, u, f, iters, slices_, host_) => {
- M = m.toInt
- U = u.toInt
- F = f.toInt
- ITERATIONS = iters.toInt
- slices = slices_.toInt
- host = host_
+
+ (0 to 5).map(i => {
+ i match {
+ case a if a < args.length => Some(args(a))
+ case _ => None
+ }
+ }).toArray match {
+ case Array(host_, m, u, f, iters, slices_) => {
+ host = host_ getOrElse "local"
+ M = (m getOrElse "100").toInt
+ U = (u getOrElse "500").toInt
+ F = (f getOrElse "10").toInt
+ ITERATIONS = (iters getOrElse "5").toInt
+ slices = (slices_ getOrElse "2").toInt
}
case _ => {
- System.err.println("Usage: SparkALS <M> <U> <F> <iters> <slices> <master>")
+ System.err.println("Usage: SparkALS [<master> <M> <U> <F> <iters> <slices>]")
System.exit(1)
}
}
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
val spark = new SparkContext(host, "SparkALS")
val R = generateR()
@@ -127,11 +108,11 @@ object SparkALS {
for (iter <- 1 to ITERATIONS) {
println("Iteration " + iter + ":")
ms = spark.parallelize(0 until M, slices)
- .map(i => updateMovie(i, msc.value(i), usc.value, Rc.value))
+ .map(i => update(i, msc.value(i), usc.value, Rc.value))
.toArray
msc = spark.broadcast(ms) // Re-broadcast ms because it was updated
us = spark.parallelize(0 until U, slices)
- .map(i => updateUser(i, usc.value(i), msc.value, Rc.value))
+ .map(i => update(i, usc.value(i), msc.value, algebra.transpose(Rc.value)))
.toArray
usc = spark.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
diff --git a/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
new file mode 100644
index 0000000000..461929fba2
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/FlumeEventCount.scala
@@ -0,0 +1,43 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+import spark.streaming._
+
+/**
+ * Produces a count of events received from Flume.
+ *
+ * This should be used in conjunction with an AvroSink in Flume. It will start
+ * an Avro server at the requested host:port address and listen for requests.
+ * Your Flume AvroSink should be pointed to this address.
+ *
+ * Usage: FlumeEventCount <master> <host> <port>
+ *
+ * <master> is a Spark master URL
+ * <host> is the host the Flume receiver will be started on - a receiver
+ * creates a server and listens for flume events.
+ * <port> is the port the Flume receiver will listen on.
+ */
+object FlumeEventCount {
+ def main(args: Array[String]) {
+ if (args.length != 3) {
+ System.err.println(
+ "Usage: FlumeEventCount <master> <host> <port>")
+ System.exit(1)
+ }
+
+ val Array(master, host, IntParam(port)) = args
+
+ val batchInterval = Milliseconds(2000)
+ // Create the context and set the batch size
+ val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval)
+
+ // Create a flume stream
+ val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY)
+
+ // Print out the count of events received from this server in each batch
+ stream.count().map(cnt => "Received " + cnt + " flume events." ).print()
+
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala
new file mode 100644
index 0000000000..8530f5c175
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/HdfsWordCount.scala
@@ -0,0 +1,36 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+
+/**
+ * Counts words in new text files created in the given directory
+ * Usage: HdfsWordCount <master> <directory>
+ * <master> is the Spark master URL.
+ * <directory> is the directory that Spark Streaming will use to find and read new text files.
+ *
+ * To run this on your local machine on directory `localdir`, run this example
+ * `$ ./run spark.streaming.examples.HdfsWordCount local[2] localdir`
+ * Then create a text file in `localdir` and the words in the file will get counted.
+ */
+object HdfsWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: HdfsWordCount <master> <directory>")
+ System.exit(1)
+ }
+
+ // Create the context
+ val ssc = new StreamingContext(args(0), "HdfsWordCount", Seconds(2))
+
+ // Create the FileInputDStream on the directory and use the
+ // stream to count words in new files created
+ val lines = ssc.textFileStream(args(1))
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
+
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java
new file mode 100644
index 0000000000..cddce16e39
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java
@@ -0,0 +1,50 @@
+package spark.streaming.examples;
+
+import spark.api.java.function.Function;
+import spark.streaming.*;
+import spark.streaming.api.java.*;
+import spark.streaming.dstream.SparkFlumeEvent;
+
+/**
+ * Produces a count of events received from Flume.
+ *
+ * This should be used in conjunction with an AvroSink in Flume. It will start
+ * an Avro server at the requested host:port address and listen for requests.
+ * Your Flume AvroSink should be pointed to this address.
+ *
+ * Usage: JavaFlumeEventCount <master> <host> <port>
+ *
+ * <master> is a Spark master URL
+ * <host> is the host the Flume receiver will be started on - a receiver
+ * creates a server and listens for flume events.
+ * <port> is the port the Flume receiver will listen on.
+ */
+public class JavaFlumeEventCount {
+ public static void main(String[] args) {
+ if (args.length != 3) {
+ System.err.println("Usage: JavaFlumeEventCount <master> <host> <port>");
+ System.exit(1);
+ }
+
+ String master = args[0];
+ String host = args[1];
+ int port = Integer.parseInt(args[2]);
+
+ Duration batchInterval = new Duration(2000);
+
+ JavaStreamingContext sc = new JavaStreamingContext(master, "FlumeEventCount", batchInterval);
+
+ // Create a flume stream on the requested host:port (not a hard-coded "localhost")
+ JavaDStream<SparkFlumeEvent> flumeStream = sc.flumeStream(host, port);
+
+ // Print out the count of events received from this server in each batch
+ flumeStream.count().map(new Function<Long, String>() {
+ @Override
+ public String call(Long in) {
+ return "Received " + in + " flume events.";
+ }
+ }).print();
+
+ sc.start();
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java
new file mode 100644
index 0000000000..4299febfd6
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java
@@ -0,0 +1,62 @@
+package spark.streaming.examples;
+
+import com.google.common.collect.Lists;
+import scala.Tuple2;
+import spark.api.java.function.FlatMapFunction;
+import spark.api.java.function.Function2;
+import spark.api.java.function.PairFunction;
+import spark.streaming.Duration;
+import spark.streaming.api.java.JavaDStream;
+import spark.streaming.api.java.JavaPairDStream;
+import spark.streaming.api.java.JavaStreamingContext;
+
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: JavaNetworkWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.JavaNetworkWordCount local[2] localhost 9999`
+ */
+public class JavaNetworkWordCount {
+ public static void main(String[] args) {
+ if (args.length < 3) {
+ System.err.println("Usage: JavaNetworkWordCount <master> <hostname> <port>\n" +
+ "In local mode, <master> should be 'local[n]' with n > 1");
+ System.exit(1);
+ }
+
+ // Create the context with a 1 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(
+ args[0], "NetworkWordCount", new Duration(1000));
+
+ // Create a NetworkInputDStream on target ip:port and count the
+ // words in input stream of \n delimited text (e.g. generated by 'nc')
+ JavaDStream<String> lines = ssc.networkTextStream(args[1], Integer.parseInt(args[2]));
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split(" "));
+ }
+ });
+ JavaPairDStream<String, Integer> wordCounts = words.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) throws Exception {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.print();
+ ssc.start();
+
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java
new file mode 100644
index 0000000000..43c3cd4dfa
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java
@@ -0,0 +1,62 @@
+package spark.streaming.examples;
+
+import com.google.common.collect.Lists;
+import scala.Tuple2;
+import spark.api.java.JavaRDD;
+import spark.api.java.function.Function2;
+import spark.api.java.function.PairFunction;
+import spark.streaming.Duration;
+import spark.streaming.api.java.JavaDStream;
+import spark.streaming.api.java.JavaPairDStream;
+import spark.streaming.api.java.JavaStreamingContext;
+
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+
+public class JavaQueueStream {
+ public static void main(String[] args) throws InterruptedException {
+ if (args.length < 1) {
+ System.err.println("Usage: JavaQueueStream <master>");
+ System.exit(1);
+ }
+
+ // Create the context
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "QueueStream", new Duration(1000));
+
+ // Create the queue through which RDDs can be pushed to
+ // a QueueInputDStream
+ Queue<JavaRDD<Integer>> rddQueue = new LinkedList<JavaRDD<Integer>>();
+
+ // Create and push some RDDs into the queue
+ List<Integer> list = Lists.newArrayList();
+ for (int i = 0; i < 1000; i++) {
+ list.add(i);
+ }
+
+ for (int i = 0; i < 30; i++) {
+ rddQueue.add(ssc.sc().parallelize(list));
+ }
+
+
+ // Create the QueueInputDStream and use it to do some processing
+ JavaDStream<Integer> inputStream = ssc.queueStream(rddQueue);
+ JavaPairDStream<Integer, Integer> mappedStream = inputStream.map(
+ new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i % 10, 1);
+ }
+ });
+ JavaPairDStream<Integer, Integer> reducedStream = mappedStream.reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+
+ reducedStream.print();
+ ssc.start();
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
new file mode 100644
index 0000000000..fe55db6e2c
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala
@@ -0,0 +1,69 @@
+package spark.streaming.examples
+
+import java.util.Properties
+import kafka.message.Message
+import kafka.producer.SyncProducerConfig
+import kafka.producer._
+import spark.SparkContext
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import spark.storage.StorageLevel
+import spark.streaming.util.RawTextHelper._
+
+object KafkaWordCount {
+ def main(args: Array[String]) {
+
+ if (args.length < 6) {
+ System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>")
+ System.exit(1)
+ }
+
+ val Array(master, hostname, port, group, topics, numThreads) = args
+
+ val sc = new SparkContext(master, "KafkaWordCount")
+ val ssc = new StreamingContext(sc, Seconds(2))
+ ssc.checkpoint("checkpoint")
+
+ val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
+ val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ wordCounts.print()
+
+ ssc.start()
+ }
+}
+
+// Produces messages whose words are random digits (0 to 9); requires all five arguments.
+object KafkaWordCountProducer {
+
+ def main(args: Array[String]) {
+ if (args.length < 5) {
+ System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.exit(1)
+ }
+
+ val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args
+
+ // ZooKeeper connection properties
+ val props = new Properties()
+ props.put("zk.connect", hostname + ":" + port)
+ props.put("serializer.class", "kafka.serializer.StringEncoder")
+
+ val config = new ProducerConfig(props)
+ val producer = new Producer[String, String](config)
+
+ // Send some messages
+ while(true) {
+ val messages = (1 to messagesPerSec.toInt).map { messageNum =>
+ (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ }.toArray
+ println(messages.mkString(","))
+ val data = new ProducerData[String, String](topic, messages)
+ producer.send(data)
+ Thread.sleep(100)
+ }
+ }
+
+}
+
+
diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
new file mode 100644
index 0000000000..32f7d57bea
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala
@@ -0,0 +1,36 @@
+package spark.streaming.examples
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ * Usage: NetworkWordCount <master> <hostname> <port>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ ./run spark.streaming.examples.NetworkWordCount local[2] localhost 9999`
+ */
+object NetworkWordCount {
+  def main(args: Array[String]) {
+    if (args.length < 3) { // <master>, <hostname> and <port> are all read below
+      System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" +
+        "In local mode, <master> should be 'local[n]' with n > 1")
+      System.exit(1)
+    }
+
+    // Create the context with a 1 second batch size
+    val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1))
+
+    // Create a NetworkInputDStream on target ip:port and count the
+    // words in the input stream of \n delimited text (e.g. generated by 'nc')
+    val lines = ssc.networkTextStream(args(1), args(2).toInt)
+    val words = lines.flatMap(_.split(" "))
+    val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+    wordCounts.print()
+    ssc.start()
+  }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala
new file mode 100644
index 0000000000..2a265d021d
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/QueueStream.scala
@@ -0,0 +1,39 @@
+package spark.streaming.examples
+
+import spark.RDD
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+
+import scala.collection.mutable.SynchronizedQueue
+
+object QueueStream {
+
+ def main(args: Array[String]) {
+ if (args.length < 1) {
+ System.err.println("Usage: QueueStream <master>")
+ System.exit(1)
+ }
+
+ // Create the context with a 1 second batch duration
+ val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1))
+
+ // Create the queue through which RDDs can be pushed to
+ // a QueueInputDStream
+ val rddQueue = new SynchronizedQueue[RDD[Int]]()
+
+ // Create the QueueInputDStream and use it to do some processing:
+ val inputStream = ssc.queueStream(rddQueue)
+ val mappedStream = inputStream.map(x => (x % 10, 1)) // key each element by x % 10
+ val reducedStream = mappedStream.reduceByKey(_ + _) // count the elements seen per key
+ reducedStream.print()
+ ssc.start()
+
+ // Create and push 30 RDDs into the queue, one per second
+ for (i <- 1 to 30) {
+ rddQueue += ssc.sc.makeRDD(1 to 1000, 10)
+ Thread.sleep(1000)
+ }
+ ssc.stop()
+ System.exit(0)
+ }
+} \ No newline at end of file
diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
new file mode 100644
index 0000000000..2eec777c54
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala
@@ -0,0 +1,46 @@
+package spark.streaming.examples
+
+import spark.util.IntParam
+import spark.storage.StorageLevel
+
+import spark.streaming._
+import spark.streaming.util.RawTextHelper
+
+/**
+ * Receives text from multiple rawNetworkStreams and counts how many '\n' delimited
+ * lines have the word 'the' in them. This is useful for benchmarking purposes. This
+ * will only work with spark.streaming.util.RawTextSender running on all worker nodes
+ * and with Spark using Kryo serialization (set Java property "spark.serializer" to
+ * "spark.KryoSerializer").
+ * Usage: RawNetworkGrep <master> <numStreams> <host> <port> <batchMillis>
+ * <master> is the Spark master URL
+ * <numStreams> is the number of rawNetworkStreams, which should be the same as the number
+ * of worker nodes in the cluster
+ * <host> is "localhost".
+ * <port> is the port on which RawTextSender is running in the worker nodes.
+ * <batchMillis> is the Spark Streaming batch duration in milliseconds.
+ */
+
+object RawNetworkGrep {
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println("Usage: RawNetworkGrep <master> <numStreams> <host> <port> <batchMillis>")
+ System.exit(1)
+ }
+
+ val Array(master, IntParam(numStreams), host, IntParam(port), IntParam(batchMillis)) = args // NOTE(review): IntParam presumably parses each numeric argument to an Int - confirm
+
+ // Create the context, using <batchMillis> as the batch duration
+ val ssc = new StreamingContext(master, "RawNetworkGrep", Milliseconds(batchMillis))
+
+ // Warm up the JVMs on master and slave for JIT compilation to kick in
+ RawTextHelper.warmUp(ssc.sc)
+
+ val rawStreams = (1 to numStreams).map(_ => // one raw stream per requested <numStreams>, all reading from host:port
+ ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray
+ val union = ssc.union(rawStreams)
+ union.filter(_.contains("the")).count().foreach(r => // print the count of lines containing "the" for each batch
+ println("Grep count: " + r.collect().mkString))
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
new file mode 100644
index 0000000000..4c6e08bc74
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewGenerator.scala
@@ -0,0 +1,85 @@
+package spark.streaming.examples.clickstream
+
+import java.net.{InetAddress,ServerSocket,Socket,SocketException}
+import java.io.{InputStreamReader, BufferedReader, PrintWriter}
+import util.Random
+
+/** A single page view on a website: the URL plus its status, zip code and user ID dimensions. */
+class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) {
+ override def toString() : String = { // tab-separated fields terminated by a newline
+ "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID)
+ }
+}
+object PageView {
+ def fromString(in : String) : PageView = { // inverse of toString; `in` must not carry the trailing newline, or the last toInt fails
+ val parts = in.split("\t")
+ new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt)
+ }
+}
+
+/** Generates streaming events to simulate page views on a website.
+ *
+ * This should be used in tandem with PageViewStream.scala. Example:
+ * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10
+ * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444
+ * */
+object PageViewGenerator {
+  val pages = Map("http://foo.com/" -> .7, // every map below pairs a value with its selection probability
+                  "http://foo.com/news" -> 0.2,
+                  "http://foo.com/contact" -> .1)
+  val httpStatus = Map(200 -> .95,
+                       404 -> .05)
+  val userZipCode = Map(94709 -> .5,
+                        94117 -> .5)
+  val userID = Map((1 to 100).map(_ -> .01):_*)
+
+  /** Picks one key of `inputMap` at random, weighting each key by its associated probability. */
+  def pickFromDistribution[T](inputMap : Map[T, Double]) : T = {
+    val rand = new Random().nextDouble()
+    var total = 0.0
+    for ((item, prob) <- inputMap) {
+      total = total + prob // running cumulative probability
+      if (total > rand) {
+        return item
+      }
+    }
+    return inputMap.take(1).head._1 // Shouldn't get here if probabilities add up to 1.0
+  }
+
+  def getNextClickEvent() : String = { // one random page view, serialized as a tab-separated line
+    val id = pickFromDistribution(userID)
+    val page = pickFromDistribution(pages)
+    val status = pickFromDistribution(httpStatus)
+    val zipCode = pickFromDistribution(userZipCode)
+    new PageView(page, status, zipCode, id).toString()
+  }
+
+  def main(args : Array[String]) {
+    if (args.length != 2) {
+      System.err.println("Usage: PageViewGenerator <port> <viewsPerSecond>")
+      System.exit(1)
+    }
+    val port = args(0).toInt
+    val viewsPerSecond = args(1).toFloat
+    val sleepDelayMs = (1000.0 / viewsPerSecond).toInt
+    val listener = new ServerSocket(port)
+    println("Listening on port: " + port)
+
+    while (true) { // serve every client that connects on its own thread
+      val socket = listener.accept()
+      new Thread() {
+        override def run = {
+          println("Got client connected from: " + socket.getInetAddress)
+          val out = new PrintWriter(socket.getOutputStream(), true)
+          // PrintWriter swallows IOExceptions, so poll its error flag to notice a disconnected client
+          while (!out.checkError()) {
+            Thread.sleep(sleepDelayMs)
+            out.write(getNextClickEvent())
+            out.flush()
+          }
+          socket.close() // reached once the client is gone, releasing the socket and ending the thread
+        }
+      }.start()
+    }
+  }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
new file mode 100644
index 0000000000..a191321d91
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala
@@ -0,0 +1,84 @@
+package spark.streaming.examples.clickstream
+
+import spark.streaming.{Seconds, StreamingContext}
+import spark.streaming.StreamingContext._
+import spark.SparkContext._
+
+/** Analyses a streaming dataset of web page views. This class demonstrates several types of
+ * operators available in Spark streaming.
+ *
+ * This should be used in tandem with PageViewGenerator.scala. Example:
+ * $ ./run spark.streaming.examples.clickstream.PageViewGenerator 44444 10
+ * $ ./run spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444
+ * */
+object PageViewStream {
+ def main(args: Array[String]) {
+ if (args.length != 3) {
+ System.err.println("Usage: PageViewStream <metric> <host> <port>")
+ System.err.println("<metric> must be one of pageCounts, slidingPageCounts," +
+ " errorRatePerZipCode, activeUserCount, popularUsersSeen")
+ System.exit(1)
+ }
+ val metric = args(0)
+ val host = args(1)
+ val port = args(2).toInt
+
+ // Create the context; the master is hard-coded to "local[2]" with a 1 second batch duration
+ val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1))
+
+ // Create a NetworkInputDStream on target host:port and convert each line to a PageView
+ val pageViews = ssc.networkTextStream(host, port)
+ .flatMap(_.split("\n"))
+ .map(PageView.fromString(_))
+
+ // Return a count of views per URL seen in each batch
+ val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey()
+
+ // Return a sliding window of page views per URL in the last ten seconds
+ val slidingPageCounts = pageViews.map(view => ((view.url, 1)))
+ .window(Seconds(10), Seconds(2))
+ .countByKey()
+
+
+ // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds
+ val statusesPerZipCode = pageViews.window(Seconds(30), Seconds(2))
+ .map(view => ((view.zipCode, view.status)))
+ .groupByKey()
+ val errorRatePerZipCode = statusesPerZipCode.map{
+ case(zip, statuses) =>
+ val normalCount = statuses.filter(_ == 200).size
+ val errorCount = statuses.size - normalCount
+ val errorRatio = errorCount.toFloat / statuses.size
+ if (errorRatio > 0.05) {"%s: **%s**".format(zip, errorRatio)} // highlight zip codes whose error ratio exceeds 5%
+ else {"%s: %s".format(zip, errorRatio)}
+ }
+
+ // Return the number of unique users in the last 15 seconds
+ val activeUserCount = pageViews.window(Seconds(15), Seconds(2))
+ .map(view => (view.userID, 1))
+ .groupByKey()
+ .count()
+ .map("Unique active users: " + _)
+
+ // An external dataset we want to join to this stream
+ val userList = ssc.sc.parallelize(
+ Map(1 -> "Patrick Wendell", 2->"Reynold Xin", 3->"Matei Zaharia").toSeq)
+
+ // Only the stream selected by <metric> gets an output operation attached
+ metric match {
+ case "pageCounts" => pageCounts.print()
+ case "slidingPageCounts" => slidingPageCounts.print()
+ case "errorRatePerZipCode" => errorRatePerZipCode.print()
+ case "activeUserCount" => activeUserCount.print()
+ case "popularUsersSeen" =>
+ // Look for users in our existing dataset and print up to ten matching names per batch
+ pageViews.map(view => (view.userID, 1))
+ .foreach((rdd, time) => rdd.join(userList)
+ .map(_._2._2)
+ .take(10)
+ .foreach(u => println("Saw user %s at time %s".format(u, time))))
+ case _ => println("Invalid metric entered: " + metric)
+ }
+
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala
new file mode 100644
index 0000000000..377bc0c98e
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala
@@ -0,0 +1,60 @@
+package spark.streaming.examples.twitter
+
+import spark.streaming.StreamingContext._
+import spark.streaming.{Seconds, StreamingContext}
+import spark.SparkContext._
+import spark.storage.StorageLevel
+
+/**
+ * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter
+ * stream. The stream is instantiated with credentials and optionally filters supplied by the
+ * command line arguments.
+ */
+object TwitterBasic {
+  def main(args: Array[String]) {
+    if (args.length < 3) {
+      System.err.println("Usage: TwitterBasic <master> <twitter_username> <twitter_password>" +
+        " [filter1] [filter2] ... [filter n]")
+      System.exit(1)
+    }
+
+    val Array(master, username, password) = args.slice(0, 3)
+    val filters = args.slice(3, args.length) // every remaining argument is a keyword filter
+
+    val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2))
+    val stream = new TwitterInputDStream(ssc, username, password, filters,
+      StorageLevel.MEMORY_ONLY_SER)
+    ssc.registerInputStream(stream)
+
+    val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#")))
+
+    val topCounts60 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(60))
+      .map{case (topic, count) => (count, topic)}
+      .transform(_.sortByKey(false))
+
+    val topCounts10 = hashTags.map((_, 1)).reduceByKeyAndWindow(_ + _, Seconds(10))
+      .map{case (topic, count) => (count, topic)}
+      .transform(_.sortByKey(false))
+
+    // Print popular hashtags; each RDD is counted once because every count() call runs a job
+    topCounts60.foreach(rdd => {
+      val total = rdd.count()
+      if (total != 0) {
+        val topList = rdd.take(5)
+        println("\nPopular topics in last 60 seconds (%s total):".format(total))
+        topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
+      }
+    })
+
+    topCounts10.foreach(rdd => {
+      val total = rdd.count()
+      if (total != 0) {
+        val topList = rdd.take(5)
+        println("\nPopular topics in last 10 seconds (%s total):".format(total))
+        topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))}
+      }
+    })
+
+    ssc.start()
+  }
+}
diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala
new file mode 100644
index 0000000000..99ed4cdc1c
--- /dev/null
+++ b/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala
@@ -0,0 +1,71 @@
+package spark.streaming.examples.twitter
+
+import spark._
+import spark.streaming._
+import dstream.{NetworkReceiver, NetworkInputDStream}
+import storage.StorageLevel
+import twitter4j._
+import twitter4j.auth.BasicAuthorization
+import collection.JavaConversions._
+
+/** A stream of Twitter statuses, potentially filtered by one or more keywords.
+ *
+ * @constructor create a new Twitter stream using the supplied username and password to authenticate.
+ * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is
+ * such that this may return a sampled subset of all tweets during each interval.
+ */
+class TwitterInputDStream(
+ @transient ssc_ : StreamingContext,
+ username: String,
+ password: String,
+ filters: Seq[String],
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[Status](ssc_) {
+
+ override def createReceiver(): NetworkReceiver[Status] = { // builds a TwitterReceiver from the constructor arguments
+ new TwitterReceiver(username, password, filters, storageLevel)
+ }
+}
+
+class TwitterReceiver(
+    username: String,
+    password: String,
+    filters: Seq[String],
+    storageLevel: StorageLevel
+  ) extends NetworkReceiver[Status] {
+
+  var twitterStream: TwitterStream = _ // stays null until onStart has created the stream
+  lazy val blockGenerator = new BlockGenerator(storageLevel)
+
+  protected override def onStart() {
+    blockGenerator.start()
+    twitterStream = new TwitterStreamFactory()
+      .getInstance(new BasicAuthorization(username, password))
+    twitterStream.addListener(new StatusListener {
+      def onStatus(status: Status) = {
+        blockGenerator += status
+      }
+      def onException(e: Exception) { logError("Error receiving tweets", e) } // report failures instead of dropping them
+      // Unimplemented
+      def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {}
+      def onTrackLimitationNotice(i: Int) {}
+      def onScrubGeo(l: Long, l1: Long) {}
+      def onStallWarning(stallWarning: StallWarning) {}
+    })
+
+    val query: FilterQuery = new FilterQuery
+    if (filters.size > 0) {
+      query.track(filters.toArray) // restrict the stream to tweets matching the keyword filters
+      twitterStream.filter(query)
+    } else {
+      twitterStream.sample() // no filters given: consume the sample stream
+    }
+    logInfo("Twitter receiver started")
+  }
+
+  protected override def onStop() {
+    blockGenerator.stop()
+    if (twitterStream != null) twitterStream.shutdown() // onStart may have failed before the stream existed
+    logInfo("Twitter receiver stopped")
+  }
+}
diff --git a/pom.xml b/pom.xml
index 52a4e9d932..3ea989a082 100644
--- a/pom.xml
+++ b/pom.xml
@@ -41,6 +41,7 @@
<module>core</module>
<module>bagel</module>
<module>examples</module>
+ <module>streaming</module>
<module>repl</module>
<module>repl-bin</module>
</modules>
@@ -54,6 +55,7 @@
<mesos.version>0.9.0-incubating</mesos.version>
<akka.version>2.0.3</akka.version>
<spray.version>1.0-M2.1</spray.version>
+ <spray.json.version>1.1.1</spray.json.version>
<slf4j.version>1.6.1</slf4j.version>
<cdh.version>4.1.2</cdh.version>
</properties>
@@ -103,6 +105,17 @@
<enabled>false</enabled>
</snapshots>
</repository>
+ <repository>
+ <id>twitter4j-repo</id>
+ <name>Twitter4J Repository</name>
+ <url>http://twitter4j.org/maven2/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
</repositories>
<pluginRepositories>
<pluginRepository>
@@ -185,7 +198,7 @@
<dependency>
<groupId>de.javakaffee</groupId>
<artifactId>kryo-serializers</artifactId>
- <version>0.9</version>
+ <version>0.20</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
@@ -223,6 +236,11 @@
<version>${spray.version}</version>
</dependency>
<dependency>
+ <groupId>cc.spray</groupId>
+ <artifactId>spray-json_${scala.version}</artifactId>
+ <version>${spray.json.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.tomdz.twirl</groupId>
<artifactId>twirl-api</artifactId>
<version>1.0.2</version>
@@ -481,6 +499,12 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
+
<properties>
<hadoop.major.version>1</hadoop.major.version>
</properties>
@@ -489,7 +513,7 @@
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
- <version>0.20.205.0</version>
+ <version>1.0.3</version>
</dependency>
</dependencies>
</dependencyManagement>
@@ -497,6 +521,12 @@
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<properties>
<hadoop.major.version>2</hadoop.major.version>
</properties>
@@ -512,6 +542,17 @@
<artifactId>hadoop-client</artifactId>
<version>2.0.0-mr1-cdh${cdh.version}</version>
</dependency>
+ <!-- Specify Avro version because Kafka also has it as a dependency -->
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <version>1.7.1.cloudera.2</version>
+ </dependency>
</dependencies>
</dependencyManagement>
</profile>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 2f67bb9921..03b8094f7d 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -10,23 +10,25 @@ import twirl.sbt.TwirlPlugin._
object SparkBuild extends Build {
// Hadoop version to build against. For example, "0.20.2", "0.20.205.0", or
// "1.0.3" for Apache releases, or "0.20.2-cdh3u5" for Cloudera Hadoop.
- val HADOOP_VERSION = "0.20.205.0"
+ val HADOOP_VERSION = "1.0.3"
val HADOOP_MAJOR_VERSION = "1"
// For Hadoop 2 versions such as "2.0.0-mr1-cdh4.1.1", set the HADOOP_MAJOR_VERSION to "2"
//val HADOOP_VERSION = "2.0.0-mr1-cdh4.1.1"
//val HADOOP_MAJOR_VERSION = "2"
- lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel)
+ lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
lazy val core = Project("core", file("core"), settings = coreSettings)
- lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core)
+ lazy val repl = Project("repl", file("repl"), settings = replSettings) dependsOn (core) dependsOn (streaming)
- lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core)
+ lazy val examples = Project("examples", file("examples"), settings = examplesSettings) dependsOn (core) dependsOn (streaming)
lazy val bagel = Project("bagel", file("bagel"), settings = bagelSettings) dependsOn (core)
+ lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core)
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -38,6 +40,7 @@ object SparkBuild extends Build {
scalacOptions := Seq(/*"-deprecation",*/ "-unchecked", "-optimize"), // -deprecation is too noisy due to usage of old Hadoop API, enable it once that's no longer an issue
unmanagedJars in Compile <<= baseDirectory map { base => (base / "lib" ** "*.jar").classpath },
retrieveManaged := true,
+ retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
transitiveClassifiers in Scope.GlobalScope := Seq("sources"),
testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))),
@@ -87,7 +90,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011",
- "org.scalatest" %% "scalatest" % "1.6.1" % "test",
+ "org.scalatest" %% "scalatest" % "1.8" % "test",
"org.scalacheck" %% "scalacheck" % "1.9" % "test",
"com.novocode" % "junit-interface" % "0.8" % "test"
),
@@ -113,7 +116,8 @@ object SparkBuild extends Build {
"Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/",
"JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
"Spray Repository" at "http://repo.spray.cc/",
- "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/"
+ "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/",
+ "Twitter4J Repository" at "http://twitter4j.org/maven2/"
),
libraryDependencies ++= Seq(
@@ -125,7 +129,7 @@ object SparkBuild extends Build {
"org.apache.hadoop" % "hadoop-core" % HADOOP_VERSION,
"asm" % "asm-all" % "3.3.1",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
- "de.javakaffee" % "kryo-serializers" % "0.9",
+ "de.javakaffee" % "kryo-serializers" % "0.20",
"com.typesafe.akka" % "akka-actor" % "2.0.3",
"com.typesafe.akka" % "akka-remote" % "2.0.3",
"com.typesafe.akka" % "akka-slf4j" % "2.0.3",
@@ -133,6 +137,7 @@ object SparkBuild extends Build {
"colt" % "colt" % "1.2.0",
"cc.spray" % "spray-can" % "1.0-M2.1",
"cc.spray" % "spray-server" % "1.0-M2.1",
+ "cc.spray" %% "spray-json" % "1.1.1",
"org.apache.mesos" % "mesos" % "0.9.0-incubating"
) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq,
unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") }
@@ -148,11 +153,22 @@ object SparkBuild extends Build {
)
def examplesSettings = sharedSettings ++ Seq(
- name := "spark-examples"
+ name := "spark-examples",
+ libraryDependencies ++= Seq(
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3"
+ )
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
+ def streamingSettings = sharedSettings ++ Seq(
+ name := "spark-streaming",
+ libraryDependencies ++= Seq(
+ "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile",
+ "com.github.sgroschupf" % "zkclient" % "0.1"
+ )
+ ) ++ assemblySettings ++ extraAssemblySettings
+
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
diff --git a/pyspark b/pyspark
new file mode 100755
index 0000000000..ab7f4f50c0
--- /dev/null
+++ b/pyspark
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+# Figure out where the Spark framework is installed (quoted so paths with spaces work)
+FWDIR="$(cd "$(dirname "$0")" && pwd)"
+
+# Export this as SPARK_HOME
+export SPARK_HOME="$FWDIR"
+
+# Exit if the user hasn't compiled Spark
+if [ ! -e "$SPARK_HOME/repl/target" ]; then
+  echo "Failed to find Spark classes in $SPARK_HOME/repl/target" >&2
+  echo "You need to compile Spark before running this program" >&2
+  exit 1
+fi
+
+# Load environment variables from conf/spark-env.sh, if it exists
+if [ -e "$FWDIR/conf/spark-env.sh" ] ; then
+  . "$FWDIR/conf/spark-env.sh"
+fi
+
+# Figure out which Python executable to use
+if [ -z "$PYSPARK_PYTHON" ] ; then
+  PYSPARK_PYTHON="python"
+fi
+export PYSPARK_PYTHON
+
+# Add the PySpark classes to the Python path:
+export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH"
+
+# Load the PySpark shell.py script when ./pyspark is used interactively:
+export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
+export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py"
+
+# Launch with `scala` by default:
+if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then
+  export SPARK_LAUNCH_WITH_SCALA=1
+fi
+
+exec "$PYSPARK_PYTHON" "$@"
diff --git a/python/.gitignore b/python/.gitignore
new file mode 100644
index 0000000000..5c56e638f9
--- /dev/null
+++ b/python/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+docs/
diff --git a/python/epydoc.conf b/python/epydoc.conf
new file mode 100644
index 0000000000..45102cd9fe
--- /dev/null
+++ b/python/epydoc.conf
@@ -0,0 +1,19 @@
+[epydoc] # Epydoc section marker (required by ConfigParser)
+
+# Information about the project.
+name: PySpark
+url: http://spark-project.org
+
+# The list of modules to document. Modules can be named using
+# dotted names, module filenames, or package directory names.
+# This option may be repeated.
+modules: pyspark
+
+# Write html output to the directory "docs/"
+output: html
+target: docs/
+
+private: no
+
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+ pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
diff --git a/python/examples/als.py b/python/examples/als.py
new file mode 100755
index 0000000000..010f80097f
--- /dev/null
+++ b/python/examples/als.py
@@ -0,0 +1,71 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from os.path import realpath
+import sys
+
+import numpy as np
+from numpy.random import rand
+from numpy import matrix
+from pyspark import SparkContext
+
+LAMBDA = 0.01 # regularization
+np.random.seed(42)
+
+def rmse(R, ms, us):
+    diff = R - ms * us.T  # residual between the ratings and their low-rank reconstruction
+    return np.sqrt(np.sum(np.power(diff, 2)) / (M * U))  # mean over all M * U entries, then root
+
+def update(i, vec, mat, ratings): # solves for the new factor row i with `mat` held fixed; `vec` is not used
+ uu = mat.shape[0] # number of rows in the fixed factor matrix
+ ff = mat.shape[1] # number of columns (features) in the fixed factor matrix
+ XtX = matrix(np.zeros((ff, ff)))
+ Xty = np.zeros((ff, 1))
+
+ for j in range(uu): # accumulate the normal-equation terms one row of `mat` at a time
+ v = mat[j, :]
+ XtX += v.T * v
+ Xty += v.T * ratings[i, j]
+ XtX += np.eye(ff, ff) * LAMBDA * uu # add LAMBDA * uu to the diagonal (regularization)
+ return np.linalg.solve(XtX, Xty) # ff x 1 solution of XtX * w = Xty
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2: # only <master> is mandatory
+ print >> sys.stderr, \
+ "Usage: PythonALS <master> <M> <U> <F> <iters> <slices>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonALS", pyFiles=[realpath(__file__)])
+ M = int(sys.argv[2]) if len(sys.argv) > 2 else 100 # every argument after <master> falls back to a default
+ U = int(sys.argv[3]) if len(sys.argv) > 3 else 500
+ F = int(sys.argv[4]) if len(sys.argv) > 4 else 10
+ ITERATIONS = int(sys.argv[5]) if len(sys.argv) > 5 else 5
+ slices = int(sys.argv[6]) if len(sys.argv) > 6 else 2
+
+ print "Running ALS with M=%d, U=%d, F=%d, iters=%d, slices=%d\n" % \
+ (M, U, F, ITERATIONS, slices)
+
+ R = matrix(rand(M, F)) * matrix(rand(U, F).T) # synthetic M x U matrix built from two random factors
+ ms = matrix(rand(M ,F)) # random initial M x F factor
+ us = matrix(rand(U, F)) # random initial U x F factor
+
+ Rb = sc.broadcast(R)
+ msb = sc.broadcast(ms)
+ usb = sc.broadcast(us)
+
+ for i in range(ITERATIONS): # alternate: recompute ms with us fixed, then us with the new ms fixed
+ ms = sc.parallelize(range(M), slices) \
+ .map(lambda x: update(x, msb.value[x, :], usb.value, Rb.value)) \
+ .collect()
+ ms = matrix(np.array(ms)[:, :, 0]) # collect() returns a list, so array ends up being
+ # a 3-d array, we take the first 2 dims for the matrix
+ msb = sc.broadcast(ms) # re-broadcast so the next step sees the updated ms
+
+ us = sc.parallelize(range(U), slices) \
+ .map(lambda x: update(x, usb.value[x, :], msb.value, Rb.value.T)) \
+ .collect()
+ us = matrix(np.array(us)[:, :, 0])
+ usb = sc.broadcast(us)
+
+ error = rmse(R, ms, us)
+ print "Iteration %d:" % i
+ print "\nRMSE: %5.4f\n" % error
diff --git a/python/examples/kmeans.py b/python/examples/kmeans.py
new file mode 100644
index 0000000000..72cf9f88c6
--- /dev/null
+++ b/python/examples/kmeans.py
@@ -0,0 +1,54 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+def parseVector(line): # converts one line of space-separated numbers into a numpy array of floats
+ return np.array([float(x) for x in line.split(' ')])
+
+
+def closestPoint(p, centers): # returns the index in `centers` of the center nearest to p
+ bestIndex = 0
+ closest = float("+inf")
+ for i in range(len(centers)):
+ tempDist = np.sum((p - centers[i]) ** 2) # squared Euclidean distance
+ if tempDist < closest: # strict comparison, so the first of several equally close centers wins
+ closest = tempDist
+ bestIndex = i
+ return bestIndex
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 5:
+ print >> sys.stderr, \
+ "Usage: PythonKMeans <master> <file> <k> <convergeDist>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonKMeans")
+ lines = sc.textFile(sys.argv[2])
+ data = lines.map(parseVector).cache()
+ K = int(sys.argv[3])
+ convergeDist = float(sys.argv[4])
+
+ # TODO: change this after we port takeSample()
+ #kPoints = data.takeSample(False, K, 34)
+ kPoints = data.take(K) # for now the first K points serve as the initial centers
+ tempDist = 1.0
+
+ while tempDist > convergeDist: # iterate until the centers move by no more than convergeDist in total
+ closest = data.map(
+ lambda p : (closestPoint(p, kPoints), (p, 1))) # (nearest center index, (point, 1))
+ pointStats = closest.reduceByKey(
+ lambda (x1, y1), (x2, y2): (x1 + x2, y1 + y2)) # per center: (sum of points, point count)
+ newPoints = pointStats.map(
+ lambda (x, (y, z)): (x, y / z)).collect() # per center: (index, mean of its points)
+
+ tempDist = sum(np.sum((kPoints[x] - y) ** 2) for (x, y) in newPoints) # total squared movement of the centers
+
+ for (x, y) in newPoints:
+ kPoints[x] = y
+
+ print "Final centers: " + str(kPoints)
diff --git a/python/examples/logistic_regression.py b/python/examples/logistic_regression.py
new file mode 100755
index 0000000000..f13698a86f
--- /dev/null
+++ b/python/examples/logistic_regression.py
@@ -0,0 +1,57 @@
+"""
+This example requires numpy (http://www.numpy.org/)
+"""
+from collections import namedtuple
+from math import exp
+from os.path import realpath
+import sys
+
+import numpy as np
+from pyspark import SparkContext
+
+
+N = 100000  # Number of data points
+D = 10  # Number of dimensions
+R = 0.7  # Scaling factor
+ITERATIONS = 5
+np.random.seed(42)
+
+
+DataPoint = namedtuple("DataPoint", ['x', 'y'])
+from logistic_regression import DataPoint  # Re-import under this file's module name so that DataPoint is properly serialized
+
+
+def generateData(): # builds the list of N synthetic labelled points
+ def generatePoint(i):
+ y = -1 if i % 2 == 0 else 1 # labels alternate between -1 and 1
+ x = np.random.normal(size=D) + (y * R) # D normal samples shifted by y * R
+ return DataPoint(x, y)
+ return [generatePoint(i) for i in range(N)]
+
+
+if __name__ == "__main__":
+    if len(sys.argv) == 1:
+        print >> sys.stderr, \
+            "Usage: PythonLR <master> [<slices>]"
+        exit(-1)
+    sc = SparkContext(sys.argv[1], "PythonLR", pyFiles=[realpath(__file__)])
+    slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+    points = sc.parallelize(generateData(), slices).cache()
+
+    # Initialize w to a random value
+    w = 2 * np.random.ranf(size=D) - 1
+    print "Initial w: " + str(w)
+
+    def add(x, y):
+        x += y  # accumulate in place to avoid allocating a new array per pair
+        return x
+
+    for i in range(1, ITERATIONS + 1):
+        print "On iteration %i" % i
+        # Gradient of the logistic loss: (sigmoid(y * w.x) - 1) * y * x, summed over all points
+        gradient = points.map(lambda p:
+            (1.0 / (1.0 + exp(-p.y * np.dot(w, p.x))) - 1.0) * p.y * p.x
+        ).reduce(add)
+        w -= gradient
+
+    print "Final w: " + str(w)
diff --git a/python/examples/pi.py b/python/examples/pi.py
new file mode 100644
index 0000000000..127cba029b
--- /dev/null
+++ b/python/examples/pi.py
@@ -0,0 +1,21 @@
+import sys
+from random import random
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonPi <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonPi")
+ slices = int(sys.argv[2]) if len(sys.argv) > 2 else 2
+ n = 100000 * slices # total number of random samples
+ def f(_): # returns 1 when a random point of the square [-1, 1) x [-1, 1) lies inside the unit circle
+ x = random() * 2 - 1
+ y = random() * 2 - 1
+ return 1 if x ** 2 + y ** 2 < 1 else 0
+ count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) # number of samples inside the circle
+ print "Pi is roughly %f" % (4.0 * count / n) # circle/square area ratio is pi/4
diff --git a/python/examples/transitive_closure.py b/python/examples/transitive_closure.py
new file mode 100644
index 0000000000..73f7f8fbaf
--- /dev/null
+++ b/python/examples/transitive_closure.py
@@ -0,0 +1,50 @@
+import sys
+from random import Random
+
+from pyspark import SparkContext
+
+numEdges = 200
+numVertices = 100
+rand = Random(42)
+
+
+def generateGraph():
+ edges = set()
+ while len(edges) < numEdges:
+ src = rand.randrange(0, numEdges)
+ dst = rand.randrange(0, numEdges)
+ if src != dst:
+ edges.add((src, dst))
+ return edges
+
+
+if __name__ == "__main__":
+ if len(sys.argv) == 1:
+ print >> sys.stderr, \
+ "Usage: PythonTC <master> [<slices>]"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonTC")
+ slices = sys.argv[2] if len(sys.argv) > 2 else 2
+ tc = sc.parallelize(generateGraph(), slices).cache()
+
+ # Linear transitive closure: each round grows paths by one edge,
+ # by joining the graph's edges with the already-discovered paths.
+ # e.g. join the path (y, z) from the TC with the edge (x, y) from
+ # the graph to obtain the path (x, z).
+
+ # Because join() joins on keys, the edges are stored in reversed order.
+ edges = tc.map(lambda (x, y): (y, x))
+
+ oldCount = 0L
+ nextCount = tc.count()
+ while True:
+ oldCount = nextCount
+ # Perform the join, obtaining an RDD of (y, (z, x)) pairs,
+ # then project the result to obtain the new (x, z) paths.
+ new_edges = tc.join(edges).map(lambda (_, (a, b)): (b, a))
+ tc = tc.union(new_edges).distinct().cache()
+ nextCount = tc.count()
+ if nextCount == oldCount:
+ break
+
+ print "TC has %i edges" % tc.count()
diff --git a/python/examples/wordcount.py b/python/examples/wordcount.py
new file mode 100644
index 0000000000..857160624b
--- /dev/null
+++ b/python/examples/wordcount.py
@@ -0,0 +1,19 @@
+import sys
+from operator import add
+
+from pyspark import SparkContext
+
+
+if __name__ == "__main__":
+ if len(sys.argv) < 3:
+ print >> sys.stderr, \
+ "Usage: PythonWordCount <master> <file>"
+ exit(-1)
+ sc = SparkContext(sys.argv[1], "PythonWordCount")
+ lines = sc.textFile(sys.argv[2], 1)
+ counts = lines.flatMap(lambda x: x.split(' ')) \
+ .map(lambda x: (x, 1)) \
+ .reduceByKey(add)
+ output = counts.collect()
+ for (word, count) in output:
+ print "%s : %i" % (word, count)
diff --git a/python/lib/PY4J_LICENSE.txt b/python/lib/PY4J_LICENSE.txt
new file mode 100644
index 0000000000..a70279ca14
--- /dev/null
+++ b/python/lib/PY4J_LICENSE.txt
@@ -0,0 +1,27 @@
+
+Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+- Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+- Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+
+- The name of the author may not be used to endorse or promote products
+derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/python/lib/PY4J_VERSION.txt b/python/lib/PY4J_VERSION.txt
new file mode 100644
index 0000000000..04a0cd52a8
--- /dev/null
+++ b/python/lib/PY4J_VERSION.txt
@@ -0,0 +1 @@
+b7924aabe9c5e63f0a4d8bbd17019534c7ec014e
diff --git a/python/lib/py4j0.7.egg b/python/lib/py4j0.7.egg
new file mode 100644
index 0000000000..f8a339d8ee
--- /dev/null
+++ b/python/lib/py4j0.7.egg
Binary files differ
diff --git a/python/lib/py4j0.7.jar b/python/lib/py4j0.7.jar
new file mode 100644
index 0000000000..73b7ddb7d1
--- /dev/null
+++ b/python/lib/py4j0.7.jar
Binary files differ
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
new file mode 100644
index 0000000000..3e8bca62f0
--- /dev/null
+++ b/python/pyspark/__init__.py
@@ -0,0 +1,27 @@
+"""
+PySpark is a Python API for Spark.
+
+Public classes:
+
+ - L{SparkContext<pyspark.context.SparkContext>}
+ Main entry point for Spark functionality.
+ - L{RDD<pyspark.rdd.RDD>}
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ - L{Broadcast<pyspark.broadcast.Broadcast>}
+ A broadcast variable that gets reused across tasks.
+ - L{Accumulator<pyspark.accumulators.Accumulator>}
+ An "add-only" shared variable that tasks can only add values to.
+ - L{SparkFiles<pyspark.files.SparkFiles>}
+ Access files shipped with jobs.
+"""
+import sys
+import os
+sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
+
+
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD
+from pyspark.files import SparkFiles
+
+
+__all__ = ["SparkContext", "RDD", "SparkFiles"]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
new file mode 100644
index 0000000000..8011779ddc
--- /dev/null
+++ b/python/pyspark/accumulators.py
@@ -0,0 +1,187 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> a = sc.accumulator(1)
+>>> a.value
+1
+>>> a.value = 2
+>>> a.value
+2
+>>> a += 5
+>>> a.value
+7
+
+>>> sc.accumulator(1.0).value
+1.0
+
+>>> sc.accumulator(1j).value
+1j
+
+>>> rdd = sc.parallelize([1,2,3])
+>>> def f(x):
+... global a
+... a += x
+>>> rdd.foreach(f)
+>>> a.value
+13
+
+>>> class VectorAccumulatorParam(object):
+... def zero(self, value):
+... return [0.0] * len(value)
+... def addInPlace(self, val1, val2):
+... for i in xrange(len(val1)):
+... val1[i] += val2[i]
+... return val1
+>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
+>>> va.value
+[1.0, 2.0, 3.0]
+>>> def g(x):
+... global va
+... va += [x] * 3
+>>> rdd.foreach(g)
+>>> va.value
+[7.0, 8.0, 9.0]
+
+>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> def h(x):
+... global a
+... a.value = 7
+>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Py4JJavaError:...
+
+>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+ ...
+Exception:...
+"""
+
+import struct
+import SocketServer
+import threading
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import read_int, read_with_length, load_pickle
+
+
+# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
+# the local accumulator updates back to the driver program at the end of a task.
+_accumulatorRegistry = {}
+
+
+def _deserialize_accumulator(aid, zero_value, accum_param):
+ from pyspark.accumulators import _accumulatorRegistry
+ accum = Accumulator(aid, zero_value, accum_param)
+ accum._deserialized = True
+ _accumulatorRegistry[aid] = accum
+ return accum
+
+
+class Accumulator(object):
+ """
+ A shared variable that can be accumulated, i.e., has a commutative and associative "add"
+ operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
+ operator, but only the driver program is allowed to access its value, using C{value}.
+ Updates from the workers get propagated automatically to the driver program.
+
+ While C{SparkContext} supports accumulators for primitive data types like C{int} and
+ C{float}, users can also define accumulators for custom types by providing a custom
+ C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest
+ of this module for an example.
+ """
+
+ def __init__(self, aid, value, accum_param):
+ """Create a new Accumulator with a given initial value and AccumulatorParam object"""
+ from pyspark.accumulators import _accumulatorRegistry
+ self.aid = aid
+ self.accum_param = accum_param
+ self._value = value
+ self._deserialized = False
+ _accumulatorRegistry[aid] = self
+
+ def __reduce__(self):
+ """Custom serialization; saves the zero value from our AccumulatorParam"""
+ param = self.accum_param
+ return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
+
+ @property
+ def value(self):
+ """Get the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ """Sets the accumulator's value; only usable in driver program"""
+ if self._deserialized:
+ raise Exception("Accumulator.value cannot be accessed inside tasks")
+ self._value = value
+
+ def __iadd__(self, term):
+ """The += operator; adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+ return self
+
+ def __str__(self):
+ return str(self._value)
+
+ def __repr__(self):
+ return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)
+
+
+class AddingAccumulatorParam(object):
+ """
+ An AccumulatorParam that uses the + operators to add values. Designed for simple types
+ such as integers, floats, and lists. Requires the zero value for the underlying type
+ as a parameter.
+ """
+
+ def __init__(self, zero_value):
+ self.zero_value = zero_value
+
+ def zero(self, value):
+ return self.zero_value
+
+ def addInPlace(self, value1, value2):
+ value1 += value2
+ return value1
+
+
+# Singleton accumulator params for some standard types
+INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
+FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
+COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
+
+
+class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+ def handle(self):
+ from pyspark.accumulators import _accumulatorRegistry
+ num_updates = read_int(self.rfile)
+ for _ in range(num_updates):
+ (aid, update) = load_pickle(read_with_length(self.rfile))
+ _accumulatorRegistry[aid] += update
+ # Write a byte in acknowledgement
+ self.wfile.write(struct.pack("!b", 1))
+
+
+def _start_update_server():
+ """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
+ server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+ thread = threading.Thread(target=server.serve_forever)
+ thread.daemon = True
+ thread.start()
+ return server
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
new file mode 100644
index 0000000000..93876fa738
--- /dev/null
+++ b/python/pyspark/broadcast.py
@@ -0,0 +1,48 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> b = sc.broadcast([1, 2, 3, 4, 5])
+>>> b.value
+[1, 2, 3, 4, 5]
+
+>>> from pyspark.broadcast import _broadcastRegistry
+>>> _broadcastRegistry[b.bid] = b
+>>> from cPickle import dumps, loads
+>>> loads(dumps(b)).value
+[1, 2, 3, 4, 5]
+
+>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+
+>>> large_broadcast = sc.broadcast(list(range(10000)))
+"""
+# Holds broadcasted data received from Java, keyed by its id.
+_broadcastRegistry = {}
+
+
+def _from_id(bid):
+ from pyspark.broadcast import _broadcastRegistry
+ if bid not in _broadcastRegistry:
+ raise Exception("Broadcast variable '%s' not loaded!" % bid)
+ return _broadcastRegistry[bid]
+
+
+class Broadcast(object):
+ def __init__(self, bid, value, java_broadcast=None, pickle_registry=None):
+ self.value = value
+ self.bid = bid
+ self._jbroadcast = java_broadcast
+ self._pickle_registry = pickle_registry
+
+ def __reduce__(self):
+ self._pickle_registry.add(self)
+ return (_from_id, (self.bid, ))
+
+
+def _test():
+ import doctest
+ doctest.testmod()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
new file mode 100644
index 0000000000..6a7c23a069
--- /dev/null
+++ b/python/pyspark/cloudpickle.py
@@ -0,0 +1,974 @@
+"""
+This class is defined to override standard pickle functionality
+
+The goals of it follow:
+-Serialize lambdas and nested functions to compiled byte code
+-Deal with main module correctly
+-Deal with other non-serializable objects
+
+It does not include an unpickler, as standard python unpickling suffices.
+
+This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
+<http://www.picloud.com>`_.
+
+Copyright (c) 2012, Regents of the University of California.
+Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of the University of California, Berkeley nor the
+ names of its contributors may be used to endorse or promote
+ products derived from this software without specific prior written
+ permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+import operator
+import os
+import pickle
+import struct
+import sys
+import types
+from functools import partial
+import itertools
+from copy_reg import _extension_registry, _inverted_registry, _extension_cache
+import new
+import dis
+import traceback
+
+#relevant opcodes
+STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
+DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
+LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
+GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]
+
+HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
+EXTENDED_ARG = chr(dis.EXTENDED_ARG)
+
+import logging
+cloudLog = logging.getLogger("Cloud.Transport")
+
+try:
+ import ctypes
+except (MemoryError, ImportError):
+ logging.warning('Exception raised on importing ctypes. Likely python bug.. some functionality will be disabled', exc_info = True)
+ ctypes = None
+ PyObject_HEAD = None
+else:
+
+ # for reading internal structures
+ PyObject_HEAD = [
+ ('ob_refcnt', ctypes.c_size_t),
+ ('ob_type', ctypes.c_void_p),
+ ]
+
+
+try:
+ from cStringIO import StringIO
+except ImportError:
+ from StringIO import StringIO
+
+# These helper functions were copied from PiCloud's util module.
+def islambda(func):
+ return getattr(func,'func_name') == '<lambda>'
+
+def xrange_params(xrangeobj):
+ """Returns a 3 element tuple describing the xrange start, step, and len
+ respectively
+
+ Note: Only guarentees that elements of xrange are the same. parameters may
+ be different.
+ e.g. xrange(1,1) is interpretted as xrange(0,0); both behave the same
+ though w/ iteration
+ """
+
+ xrange_len = len(xrangeobj)
+ if not xrange_len: #empty
+ return (0,1,0)
+ start = xrangeobj[0]
+ if xrange_len == 1: #one element
+ return start, 1, 1
+ return (start, xrangeobj[1] - xrangeobj[0], xrange_len)
+
+#debug variables intended for developer use:
+printSerialization = False
+printMemoization = False
+
+useForcedImports = True #Should I use forced imports for tracking?
+
+
+
+class CloudPickler(pickle.Pickler):
+
+ dispatch = pickle.Pickler.dispatch.copy()
+ savedForceImports = False
+ savedDjangoEnv = False #hack tro transport django environment
+
+ def __init__(self, file, protocol=None, min_size_to_save= 0):
+ pickle.Pickler.__init__(self,file,protocol)
+ self.modules = set() #set of modules needed to depickle
+ self.globals_ref = {} # map ids to dictionary. used to ensure that functions can share global env
+
+ def dump(self, obj):
+ # note: not thread safe
+ # minimal side-effects, so not fixing
+ recurse_limit = 3000
+ base_recurse = sys.getrecursionlimit()
+ if base_recurse < recurse_limit:
+ sys.setrecursionlimit(recurse_limit)
+ self.inject_addons()
+ try:
+ return pickle.Pickler.dump(self, obj)
+ except RuntimeError, e:
+ if 'recursion' in e.args[0]:
+ msg = """Could not pickle object as excessively deep recursion required.
+ Try _fast_serialization=2 or contact PiCloud support"""
+ raise pickle.PicklingError(msg)
+ finally:
+ new_recurse = sys.getrecursionlimit()
+ if new_recurse == recurse_limit:
+ sys.setrecursionlimit(base_recurse)
+
+ def save_buffer(self, obj):
+ """Fallback to save_string"""
+ pickle.Pickler.save_string(self,str(obj))
+ dispatch[buffer] = save_buffer
+
+ #block broken objects
+ def save_unsupported(self, obj, pack=None):
+ raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
+ dispatch[types.GeneratorType] = save_unsupported
+
+ #python2.6+ supports slice pickling. some py2.5 extensions might as well. We just test it
+ try:
+ slice(0,1).__reduce__()
+ except TypeError: #can't pickle -
+ dispatch[slice] = save_unsupported
+
+ #itertools objects do not pickle!
+ for v in itertools.__dict__.values():
+ if type(v) is type:
+ dispatch[v] = save_unsupported
+
+
+ def save_dict(self, obj):
+ """hack fix
+ If the dict is a global, deal with it in a special way
+ """
+ #print 'saving', obj
+ if obj is __builtins__:
+ self.save_reduce(_get_module_builtins, (), obj=obj)
+ else:
+ pickle.Pickler.save_dict(self, obj)
+ dispatch[pickle.DictionaryType] = save_dict
+
+
+ def save_module(self, obj, pack=struct.pack):
+ """
+ Save a module as an import
+ """
+ #print 'try save import', obj.__name__
+ self.modules.add(obj)
+ self.save_reduce(subimport,(obj.__name__,), obj=obj)
+ dispatch[types.ModuleType] = save_module #new type
+
+ def save_codeobject(self, obj, pack=struct.pack):
+ """
+ Save a code object
+ """
+ #print 'try to save codeobj: ', obj
+ args = (
+ obj.co_argcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code,
+ obj.co_consts, obj.co_names, obj.co_varnames, obj.co_filename, obj.co_name,
+ obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
+ )
+ self.save_reduce(types.CodeType, args, obj=obj)
+ dispatch[types.CodeType] = save_codeobject #new type
+
+ def save_function(self, obj, name=None, pack=struct.pack):
+ """ Registered with the dispatch to handle all function types.
+
+ Determines what kind of function obj is (e.g. lambda, defined at
+ interactive prompt, etc) and handles the pickling appropriately.
+
+ NOTE(review): the ``name`` argument is overwritten with obj.__name__
+ below, and ``pack`` is not read anywhere in this method.
+ """
+ write = self.write
+
+ name = obj.__name__
+ modname = pickle.whichmodule(obj, name)
+ #print 'which gives %s %s %s' % (modname, obj, name)
+ try:
+ themodule = sys.modules[modname]
+ except KeyError: # eval'd items such as namedtuple give invalid items for their function __module__
+ modname = '__main__'
+
+ # Functions attributed to __main__ have no module recorded for them.
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ # The django settings module is transported at most once per pickler
+ # (guarded by savedDjangoEnv).
+ if not self.savedDjangoEnv:
+ #hack for django - if we detect the settings module, we transport it
+ django_settings = os.environ.get('DJANGO_SETTINGS_MODULE', '')
+ if django_settings:
+ django_mod = sys.modules.get(django_settings)
+ if django_mod:
+ cloudLog.debug('Transporting django settings %s during save of %s', django_mod, name)
+ self.savedDjangoEnv = True
+ self.modules.add(django_mod)
+ # MARK ... POP_MARK brackets the reduce between a mark and its pop.
+ write(pickle.MARK)
+ self.save_reduce(django_settings_load, (django_mod.__name__,), obj=django_mod)
+ write(pickle.POP_MARK)
+
+
+ # if func is lambda, def'ed at prompt, is in main, or is nested, then
+ # we'll pickle the actual function object rather than simply saving a
+ # reference (as is done in default pickler), via save_function_tuple.
+ if islambda(obj) or obj.func_code.co_filename == '<stdin>' or themodule == None:
+ #Force server to import modules that have been imported in main
+ modList = None
+ # The forced-import list is taken from __main__ at most once per pickler.
+ if themodule == None and not self.savedForceImports:
+ mainmod = sys.modules['__main__']
+ if useForcedImports and hasattr(mainmod,'___pyc_forcedImports__'):
+ modList = list(mainmod.___pyc_forcedImports__)
+ self.savedForceImports = True
+ self.save_function_tuple(obj, modList)
+ return
+ else: # func is nested
+ # If the module attribute of that name is missing or is a different
+ # object, the function cannot be referenced by name.
+ klass = getattr(themodule, name, None)
+ if klass is None or klass is not obj:
+ self.save_function_tuple(obj, [themodule])
+ return
+
+ # From here on the function is saved by reference as a GLOBAL opcode.
+ if obj.__dict__:
+ # essentially save_reduce, but workaround needed to avoid recursion
+ self.save(_restore_attr)
+ write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ self.save(obj.__dict__)
+ write(pickle.TUPLE + pickle.REDUCE)
+ else:
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.FunctionType] = save_function
+
+ def save_function_tuple(self, func, forced_imports):
+ """ Pickles an actual func object.
+
+ A func comprises: code, globals, defaults, closure, and dict. We
+ extract and save these, injecting reducing functions at certain points
+ to recreate the func object. Keep in mind that some of these pieces
+ can contain a ref to the func itself. Thus, a naive save on these
+ pieces could trigger an infinite loop of save's. To get around that,
+ we first create a skeleton func object using just the code (this is
+ safe, since this won't contain a ref to the func), and memoize it as
+ soon as it's created. The other stuff can then be filled in later.
+
+ forced_imports is either falsy (nothing is emitted for it) or an
+ iterable of objects with a __name__ attribute; only the names are
+ pickled, as the argument to _modules_to_main.
+ """
+ save = self.save
+ write = self.write
+
+ # save the modules (if any)
+ if forced_imports:
+ write(pickle.MARK)
+ save(_modules_to_main)
+ #print 'forced imports are', forced_imports
+
+ forced_names = map(lambda m: m.__name__, forced_imports)
+ save((forced_names,))
+
+ #save((forced_imports,))
+ write(pickle.REDUCE)
+ write(pickle.POP_MARK)
+
+ code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
+
+ save(_fill_function) # skeleton function updater
+ write(pickle.MARK) # beginning of tuple that _fill_function expects
+
+ # create a skeleton function object and memoize it
+ # (only the closure's length is passed to _make_skel_func here; the
+ # closure contents are saved further down)
+ save(_make_skel_func)
+ save((code, len(closure), base_globals))
+ write(pickle.REDUCE)
+ self.memoize(func)
+
+ # save the rest of the func data needed by _fill_function
+ save(f_globals)
+ save(defaults)
+ save(closure)
+ save(dct)
+ write(pickle.TUPLE)
+ write(pickle.REDUCE) # applies _fill_function on the tuple
+
+ @staticmethod
+ def extract_code_globals(co):
+ """
+ Find all globals names read or written to by codeblock co
+
+ Walks co.co_code one instruction at a time and returns the set of
+ entries of co.co_names that are the argument of a STORE_GLOBAL,
+ DELETE_GLOBAL or LOAD_GLOBAL instruction.
+ """
+ code = co.co_code
+ names = co.co_names
+ out_names = set()
+
+ n = len(code)
+ i = 0
+ extended_arg = 0
+ while i < n:
+ op = code[i]
+
+ i = i+1
+ # Opcodes at or above HAVE_ARGUMENT are followed by a two-byte argument,
+ # low byte first.
+ if op >= HAVE_ARGUMENT:
+ oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
+ extended_arg = 0
+ i = i+2
+ # EXTENDED_ARG contributes the high bits of the next instruction's argument.
+ if op == EXTENDED_ARG:
+ extended_arg = oparg*65536L
+ if op in GLOBAL_OPS:
+ out_names.add(names[oparg])
+ #print 'extracted', out_names, ' from ', names
+ return out_names
+
+ def extract_func_data(self, func):
+ """
+ Turn the function into a tuple of data necessary to recreate it:
+ code, globals, defaults, closure, dict, base_globals
+
+ globals contains only the names the code references that are present
+ in func.func_globals; base_globals is a dict shared, through
+ self.globals_ref, by all functions with the same func_globals object.
+
+ Raises pickle.PicklingError if a closure cell is still empty.
+ """
+ code = func.func_code
+
+ # extract all global ref's
+ func_global_refs = CloudPickler.extract_code_globals(code)
+ if code.co_consts: # see if nested function have any global refs
+ # NOTE(review): only code objects directly in co_consts are scanned;
+ # code objects nested deeper are not visited here.
+ for const in code.co_consts:
+ if type(const) is types.CodeType and const.co_names:
+ func_global_refs = func_global_refs.union( CloudPickler.extract_code_globals(const))
+ # process all variables referenced by global environment
+ f_globals = {}
+ for var in func_global_refs:
+ #Some names, such as class functions are not global - we don't need them
+ if func.func_globals.has_key(var):
+ f_globals[var] = func.func_globals[var]
+
+ # defaults requires no processing
+ defaults = func.func_defaults
+
+ def get_contents(cell):
+ # Return the value held by a closure cell.
+ try:
+ return cell.cell_contents
+ except ValueError, e: #cell is empty error on not yet assigned
+ raise pickle.PicklingError('Function to be pickled has free variables that are referenced before assignment in enclosing scope')
+
+
+ # process closure
+ if func.func_closure:
+ closure = map(get_contents, func.func_closure)
+ else:
+ closure = []
+
+ # save the dict
+ dct = func.func_dict
+
+ # Developer-only trace, enabled by the module-level printSerialization flag.
+ if printSerialization:
+ outvars = ['code: ' + str(code) ]
+ outvars.append('globals: ' + str(f_globals))
+ outvars.append('defaults: ' + str(defaults))
+ outvars.append('closure: ' + str(closure))
+ print 'function ', func, 'is extracted to: ', ', '.join(outvars)
+
+ # Reuse (or create and remember) the dict keyed by id(func.func_globals).
+ base_globals = self.globals_ref.get(id(func.func_globals), {})
+ self.globals_ref[id(func.func_globals)] = base_globals
+
+ return (code, f_globals, defaults, closure, dct, base_globals)
+
+ def save_global(self, obj, name=None, pack=struct.pack):
+ """Pickle a class, type or builtin function.
+
+ The object is saved as a GLOBAL reference (module name plus name)
+ unless it is a class/type that cannot be reached under that name in
+ its module (or that lives in __main__); such classes are saved by
+ value as a call to type(obj) with the name, bases and a copy of the
+ class __dict__.
+
+ Raises pickle.PicklingError when the module cannot be imported or a
+ builtin cannot be located.
+ """
+ write = self.write
+ memo = self.memo
+
+ if name is None:
+ name = obj.__name__
+
+ modname = getattr(obj, "__module__", None)
+ if modname is None:
+ modname = pickle.whichmodule(obj, name)
+
+ try:
+ __import__(modname)
+ themodule = sys.modules[modname]
+ except (ImportError, KeyError, AttributeError): #should never occur
+ raise pickle.PicklingError(
+ "Can't pickle %r: Module %s cannot be found" %
+ (obj, modname))
+
+ # Objects attributed to __main__ have no module recorded for them.
+ if modname == '__main__':
+ themodule = None
+
+ if themodule:
+ self.modules.add(themodule)
+
+ # sendRef stays True when the object can be saved by reference.
+ sendRef = True
+ typ = type(obj)
+ #print 'saving', obj, typ
+ try:
+ try: #Deal with case when getattribute fails with exceptions
+ klass = getattr(themodule, name)
+ except (AttributeError):
+ if modname == '__builtin__': #new.* are misreported
+ modname = 'new'
+ __import__(modname)
+ themodule = sys.modules[modname]
+ try:
+ klass = getattr(themodule, name)
+ except AttributeError, a:
+ #print themodule, name, obj, type(obj)
+ raise pickle.PicklingError("Can't pickle builtin %s" % obj)
+ else:
+ raise
+
+ except (ImportError, KeyError, AttributeError):
+ # Lookup failed: classes and types fall back to being saved by value.
+ if typ == types.TypeType or typ == types.ClassType:
+ sendRef = False
+ else: #we can't deal with this
+ raise
+ else:
+ # The name resolves to a different object, so a reference would be wrong.
+ if klass is not obj and (typ == types.TypeType or typ == types.ClassType):
+ sendRef = False
+ if not sendRef:
+ #note: Third party types might crash this - add better checks!
+ d = dict(obj.__dict__) #copy dict proxy to a dict
+ if not isinstance(d.get('__dict__', None), property): # don't extract dict that are properties
+ d.pop('__dict__',None)
+ d.pop('__weakref__',None)
+
+ # hack as __new__ is stored differently in the __dict__
+ new_override = d.get('__new__', None)
+ if new_override:
+ d['__new__'] = obj.__new__
+
+ self.save_reduce(type(obj),(obj.__name__,obj.__bases__,
+ d),obj=obj)
+ #print 'internal reduce dask %s %s' % (obj, d)
+ return
+
+ # With protocol 2 or later, a registered extension code replaces the
+ # GLOBAL reference.
+ if self.proto >= 2:
+ code = _extension_registry.get((modname, name))
+ if code:
+ assert code > 0
+ if code <= 0xff:
+ write(pickle.EXT1 + chr(code))
+ elif code <= 0xffff:
+ write("%c%c%c" % (pickle.EXT2, code&0xff, code>>8))
+ else:
+ write(pickle.EXT4 + pack("<i", code))
+ return
+
+ write(pickle.GLOBAL + modname + '\n' + name + '\n')
+ self.memoize(obj)
+ dispatch[types.ClassType] = save_global
+ dispatch[types.BuiltinFunctionType] = save_global
+ dispatch[types.TypeType] = save_global
+
+ def save_instancemethod(self, obj):
+ #Memoization rarely is ever useful due to python bounding
+ self.save_reduce(types.MethodType, (obj.im_func, obj.im_self,obj.im_class), obj=obj)
+ dispatch[types.MethodType] = save_instancemethod
+
+ def save_inst_logic(self, obj):
+ """Inner logic to save instance. Based off pickle.save_inst
+ Supports __transient__
+
+ When obj has no __getstate__, its __dict__ is used as the state and,
+ if obj defines __transient__, a copy with those keys removed is
+ saved instead.
+ """
+ cls = obj.__class__
+
+ memo = self.memo
+ write = self.write
+ save = self.save
+
+ if hasattr(obj, '__getinitargs__'):
+ args = obj.__getinitargs__()
+ len(args) # XXX Assert it's a sequence
+ pickle._keep_alive(args, memo)
+ else:
+ args = ()
+
+ write(pickle.MARK)
+
+ # Binary protocols emit the class object followed by OBJ; the text
+ # protocol emits INST with the module and class names.
+ if self.bin:
+ save(cls)
+ for arg in args:
+ save(arg)
+ write(pickle.OBJ)
+ else:
+ for arg in args:
+ save(arg)
+ write(pickle.INST + cls.__module__ + '\n' + cls.__name__ + '\n')
+
+ self.memoize(obj)
+
+ try:
+ getstate = obj.__getstate__
+ except AttributeError:
+ stuff = obj.__dict__
+ #remove items if transient
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ # Work on a copy so the live instance keeps its attributes.
+ stuff = stuff.copy()
+ for k in list(stuff.keys()):
+ if k in transient:
+ del stuff[k]
+ else:
+ stuff = getstate()
+ pickle._keep_alive(stuff, memo)
+ save(stuff)
+ write(pickle.BUILD)
+
+
+ def save_inst(self, obj):
+ # Hack to detect PIL Image instances without importing Imaging
+ # PIL can be loaded with multiple names, so we don't check sys.modules for it
+ if hasattr(obj,'im') and hasattr(obj,'palette') and 'Image' in obj.__module__:
+ self.save_image(obj)
+ else:
+ self.save_inst_logic(obj)
+ dispatch[types.InstanceType] = save_inst
+
+ def save_property(self, obj):
+ # properties not correctly saved in python
+ self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj)
+ dispatch[property] = save_property
+
+ def save_itemgetter(self, obj):
+ """itemgetter serializer (needed for namedtuple support)
+ a bit of a pain as we need to read ctypes internals"""
+ class ItemGetterType(ctypes.Structure):
+ _fields_ = PyObject_HEAD + [
+ ('nitems', ctypes.c_size_t),
+ ('item', ctypes.py_object)
+ ]
+
+
+ itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents
+ return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,))
+
+ if PyObject_HEAD:
+ dispatch[operator.itemgetter] = save_itemgetter
+
+
+
+ def save_reduce(self, func, args, state=None,
+ listitems=None, dictitems=None, obj=None):
+ """Modified to support __transient__ on new objects
+ Change only affects protocol level 2 (which is always used by PiCloud
+
+ Raises pickle.PicklingError when args is not a tuple, when func has
+ no __call__ attribute, or when the __newobj__ arguments are
+ inconsistent with obj.
+ """
+ # Assert that args is a tuple or None
+ if not isinstance(args, types.TupleType):
+ raise pickle.PicklingError("args from reduce() should be a tuple")
+
+ # Assert that func is callable
+ if not hasattr(func, '__call__'):
+ raise pickle.PicklingError("func from reduce should be callable")
+
+ save = self.save
+ write = self.write
+
+ # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
+ if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
+ #Added fix to allow transient
+ cls = args[0]
+ if not hasattr(cls, "__new__"):
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has no __new__")
+ if obj is not None and cls is not obj.__class__:
+ raise pickle.PicklingError(
+ "args[0] from __newobj__ args has the wrong class")
+ args = args[1:]
+ save(cls)
+
+ #Don't pickle transient entries
+ # NOTE(review): state defaults to None, in which case state.copy()
+ # below would fail for an obj that defines __transient__ - confirm
+ # that callers always pass a dict state for such objects.
+ if hasattr(obj, '__transient__'):
+ transient = obj.__transient__
+ state = state.copy()
+
+ for k in list(state.keys()):
+ if k in transient:
+ del state[k]
+
+ save(args)
+ write(pickle.NEWOBJ)
+ else:
+ save(func)
+ save(args)
+ write(pickle.REDUCE)
+
+ if obj is not None:
+ self.memoize(obj)
+
+ # More new special cases (that work with older protocols as
+ # well): when __reduce__ returns a tuple with 4 or 5 items,
+ # the 4th and 5th item should be iterators that provide list
+ # items and dict items (as (key, value) tuples), or None.
+
+ if listitems is not None:
+ self._batch_appends(listitems)
+
+ if dictitems is not None:
+ self._batch_setitems(dictitems)
+
+ if state is not None:
+ #print 'obj %s has state %s' % (obj, state)
+ save(state)
+ write(pickle.BUILD)
+
+
+ def save_xrange(self, obj):
+ """Save an xrange object in python 2.5
+ Python 2.6 supports this natively
+ """
+ range_params = xrange_params(obj)
+ self.save_reduce(_build_xrange,range_params)
+
+ #python2.6+ supports xrange pickling. some py2.5 extensions might as well. We just test it
+ try:
+ xrange(0).__reduce__()
+ except TypeError: #can't pickle -- use PiCloud pickler
+ dispatch[xrange] = save_xrange
+
+ def save_partial(self, obj):
+ """Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
+ self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
+
+ if sys.version_info < (2,7): #2.7 supports partial pickling
+ dispatch[partial] = save_partial
+
+
+ def save_file(self, obj):
+ """Pickle a file object as an in-memory StringIO snapshot of its contents.
+
+ sys.stdout and sys.stderr are reduced to ``getattr(sys, ...)`` lookups.
+ Any other file must have ``name`` and ``mode`` attributes, be opened for
+ reading, not be a tty, be stat-able, and be no larger than
+ SerializingAdapter.max_transmit_data; otherwise pickle.PicklingError is
+ raised.  Standard input is always rejected.
+ """
+ import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
+ from ..transport.adapter import SerializingAdapter
+
+ if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
+ raise pickle.PicklingError("Cannot pickle files that do not map to an actual file")
+ # The standard streams are restored by attribute lookup on sys, not by content.
+ if obj.name == '<stdout>':
+ return self.save_reduce(getattr, (sys,'stdout'), obj=obj)
+ if obj.name == '<stderr>':
+ return self.save_reduce(getattr, (sys,'stderr'), obj=obj)
+ if obj.name == '<stdin>':
+ raise pickle.PicklingError("Cannot pickle standard input")
+ if hasattr(obj, 'isatty') and obj.isatty():
+ raise pickle.PicklingError("Cannot pickle files that map to tty objects")
+ if 'r' not in obj.mode:
+ raise pickle.PicklingError("Cannot pickle files that are not opened for reading")
+ name = obj.name
+ try:
+ fsize = os.stat(name).st_size
+ except OSError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name)
+
+ if obj.closed:
+ #create an empty closed string io
+ retval = pystringIO.StringIO("")
+ retval.close()
+ elif not fsize: #empty file
+ retval = pystringIO.StringIO("")
+ # stat reported size 0: read one byte to check that the file really is empty
+ try:
+ tmpfile = file(name)
+ tst = tmpfile.read(1)
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ tmpfile.close()
+ if tst != '':
+ raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name)
+ elif fsize > SerializingAdapter.max_transmit_data:
+ raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" %
+ (name,SerializingAdapter.max_transmit_data))
+ else:
+ # Re-open the file by name and copy its contents into the StringIO
+ try:
+ tmpfile = file(name)
+ contents = tmpfile.read(SerializingAdapter.max_transmit_data)
+ tmpfile.close()
+ except IOError:
+ raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name)
+ retval = pystringIO.StringIO(contents)
+ # carry the original file's current read position over to the copy
+ curloc = obj.tell()
+ retval.seek(curloc)
+
+ retval.name = name
+ self.save(retval) #save stringIO
+ self.memoize(obj)
+
+ dispatch[file] = save_file
+ """Special functions for Add-on libraries"""
+
+ def inject_numpy(self):
+ numpy = sys.modules.get('numpy')
+ if not numpy or not hasattr(numpy, 'ufunc'):
+ return
+ self.dispatch[numpy.ufunc] = self.__class__.save_ufunc
+
+ numpy_tst_mods = ['numpy', 'scipy.special']
+ def save_ufunc(self, obj):
+ """Hack function for saving numpy ufunc objects"""
+ name = obj.__name__
+ for tst_mod_name in self.numpy_tst_mods:
+ tst_mod = sys.modules.get(tst_mod_name, None)
+ if tst_mod:
+ if name in tst_mod.__dict__:
+ self.save_reduce(_getobject, (tst_mod_name, name))
+ return
+ raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' % str(obj))
+
+ def inject_timeseries(self):
+ """Handle bugs with pickling scikits timeseries"""
+ tseries = sys.modules.get('scikits.timeseries.tseries')
+ if not tseries or not hasattr(tseries, 'Timeseries'):
+ return
+ self.dispatch[tseries.Timeseries] = self.__class__.save_timeseries
+
+ def save_timeseries(self, obj):
+ import scikits.timeseries.tseries as ts
+
+ func, reduce_args, state = obj.__reduce__()
+ if func != ts._tsreconstruct:
+ raise pickle.PicklingError('timeseries using unexpected reconstruction function %s' % str(func))
+ state = (1,
+ obj.shape,
+ obj.dtype,
+ obj.flags.fnc,
+ obj._data.tostring(),
+ ts.getmaskarray(obj).tostring(),
+ obj._fill_value,
+ obj._dates.shape,
+ obj._dates.__array__().tostring(),
+ obj._dates.dtype, #added -- preserve type
+ obj.freq,
+ obj._optinfo,
+ )
+ return self.save_reduce(_genTimeSeries, (reduce_args, state))
+
+ def inject_email(self):
+ """Block email LazyImporters from being saved"""
+ email = sys.modules.get('email')
+ if not email:
+ return
+ self.dispatch[email.LazyImporter] = self.__class__.save_unsupported
+
+ def inject_addons(self):
+ """Plug in system. Register additional pickling functions if modules already loaded"""
+ self.inject_numpy()
+ self.inject_timeseries()
+ self.inject_email()
+
+ """Python Imaging Library"""
+ def save_image(self, obj):
+ if not obj.im and obj.fp and 'r' in obj.fp.mode and obj.fp.name \
+ and not obj.fp.closed and (not hasattr(obj, 'isatty') or not obj.isatty()):
+ #if image not loaded yet -- lazy load
+ self.save_reduce(_lazyloadImage,(obj.fp,), obj=obj)
+ else:
+ #image is loaded - just transmit it over
+ self.save_reduce(_generateImage, (obj.size, obj.mode, obj.tostring()), obj=obj)
+
+ """
+ def memoize(self, obj):
+ pickle.Pickler.memoize(self, obj)
+ if printMemoization:
+ print 'memoizing ' + str(obj)
+ """
+
+
+
+# Shorthands for legacy support
+
+def dump(obj, file, protocol=2):
+ CloudPickler(file, protocol).dump(obj)
+
+def dumps(obj, protocol=2):
+ file = StringIO()
+
+ cp = CloudPickler(file,protocol)
+ cp.dump(obj)
+
+ #print 'cloud dumped', str(obj), str(cp.modules)
+
+ return file.getvalue()
+
+
+#hack for __import__ not working as desired
+def subimport(name):
+ __import__(name)
+ return sys.modules[name]
+
+#hack to load django settings:
+def django_settings_load(name):
+ modified_env = False
+
+ if 'DJANGO_SETTINGS_MODULE' not in os.environ:
+ os.environ['DJANGO_SETTINGS_MODULE'] = name # must set name first due to circular deps
+ modified_env = True
+ try:
+ module = subimport(name)
+ except Exception, i:
+ print >> sys.stderr, 'Cloud not import django settings %s:' % (name)
+ print_exec(sys.stderr)
+ if modified_env:
+ del os.environ['DJANGO_SETTINGS_MODULE']
+ else:
+ #add project directory to sys,path:
+ if hasattr(module,'__file__'):
+ dirname = os.path.split(module.__file__)[0] + '/'
+ sys.path.append(dirname)
+
+# restores function attributes
+def _restore_attr(obj, attr):
+ for key, val in attr.items():
+ setattr(obj, key, val)
+ return obj
+
+def _get_module_builtins():
+ return pickle.__builtins__
+
+def print_exec(stream):
+ ei = sys.exc_info()
+ traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
+
+def _modules_to_main(modList):
+ """Force every module in modList to be placed into main
+
+ Each entry is either a module name (str), which is imported here, or an
+ actual module object. An import failure is reported on stderr and the
+ entry is skipped; it does not raise.
+ """
+ if not modList:
+ return
+
+ main = sys.modules['__main__']
+ for modname in modList:
+ if type(modname) is str:
+ try:
+ # NOTE(review): for a dotted name __import__ returns the top-level
+ # package, so that package is what gets bound in __main__ -- confirm
+ # this is the intent.
+ mod = __import__(modname)
+ except Exception, i: #catch all...
+ sys.stderr.write('warning: could not import %s\n. Your function may unexpectedly error due to this import failing; \
+A version mismatch is likely. Specific error was:\n' % modname)
+ print_exec(sys.stderr)
+ else:
+ setattr(main,mod.__name__, mod)
+ else:
+ #REVERSE COMPATIBILITY FOR CLOUD CLIENT 1.5 (WITH EPD)
+ #In old version actual module was sent
+ setattr(main,modname.__name__, modname)
+
+#object generators:
+def _build_xrange(start, step, len):
+ """Built xrange explicitly"""
+ return xrange(start, start + step*len, step)
+
+def _genpartial(func, args, kwds):
+ if not args:
+ args = ()
+ if not kwds:
+ kwds = {}
+ return partial(func, *args, **kwds)
+
+
+def _fill_function(func, globals, defaults, closure, dict):
+ """ Fills in the rest of function data into the skeleton function object
+ that were created via _make_skel_func().
+ """
+ func.func_globals.update(globals)
+ func.func_defaults = defaults
+ func.func_dict = dict
+
+ if len(closure) != len(func.func_closure):
+ raise pickle.UnpicklingError("closure lengths don't match up")
+ for i in range(len(closure)):
+ _change_cell_value(func.func_closure[i], closure[i])
+
+ return func
+
+def _make_skel_func(code, num_closures, base_globals = None):
+ """ Creates a skeleton function object that contains just the provided
+ code and the correct number of cells in func_closure. All other
+ func attributes (e.g. func_globals) are empty.
+
+ Raises Exception when ctypes is unavailable, because the closure cells
+ are created through the C API (PyCell_New).
+ """
+ #build closure (cells):
+ if not ctypes:
+ raise Exception('ctypes failed to import; cannot build function')
+
+ # Configure PyCell_New to take and return Python objects
+ cellnew = ctypes.pythonapi.PyCell_New
+ cellnew.restype = ctypes.py_object
+ cellnew.argtypes = (ctypes.py_object,)
+ # one placeholder cell (holding None) per closure variable
+ dummy_closure = tuple(map(lambda i: cellnew(None), range(num_closures)))
+
+ if base_globals is None:
+ base_globals = {}
+ # NOTE(review): indentation is not recoverable in this view; if this
+ # assignment is outside the `if`, a caller-supplied base_globals dict is
+ # modified in place -- confirm.
+ base_globals['__builtins__'] = __builtins__
+
+ return types.FunctionType(code, base_globals,
+ None, None, dummy_closure)
+
+# this piece of opaque code is needed below to modify 'cell' contents
+# Hand-assembled code object: load local 0 ('newval'), duplicate it, store
+# one copy through free variable 0 ('c') and return the other.
+# NOTE(review): the positional arguments presumably follow Python 2's
+# new.code(argcount, nlocals, stacksize, flags, codestring, constants,
+# names, varnames, filename, name, firstlineno, lnotab, freevars, cellvars)
+# -- confirm against the interpreter version in use.
+cell_changer_code = new.code(
+ 1, 1, 2, 0,
+ ''.join([
+ chr(dis.opmap['LOAD_FAST']), '\x00\x00',
+ chr(dis.opmap['DUP_TOP']),
+ chr(dis.opmap['STORE_DEREF']), '\x00\x00',
+ chr(dis.opmap['RETURN_VALUE'])
+ ]),
+ (), (), ('newval',), '<nowhere>', 'cell_changer', 1, '', ('c',), ()
+)
+
+def _change_cell_value(cell, newval):
+ """ Changes the contents of 'cell' object to newval """
+ return new.function(cell_changer_code, {}, None, (), (cell,))(newval)
+
+"""Constructors for 3rd party libraries
+Note: These can never be renamed due to client compatibility issues"""
+
+def _getobject(modname, attribute):
+ mod = __import__(modname)
+ return mod.__dict__[attribute]
+
+def _generateImage(size, mode, str_rep):
+ """Generate image from string representation"""
+ import Image
+ i = Image.new(mode, size)
+ i.fromstring(str_rep)
+ return i
+
+def _lazyloadImage(fp):
+ import Image
+ fp.seek(0) #works in almost any case
+ return Image.open(fp)
+
+"""Timeseries"""
+def _genTimeSeries(reduce_args, state):
+ """Rebuild a scikits timeseries from ``reduce_args`` (passed to
+ ts._tsreconstruct) and the 12-item ``state`` tuple unpacked below.
+ Returns the reconstructed time series."""
+ import scikits.timeseries.tseries as ts
+ from numpy import ndarray
+ from numpy.ma import MaskedArray
+
+
+ time_series = ts._tsreconstruct(*reduce_args)
+
+ #from setstate modified
+ (ver, shp, typ, isf, raw, msk, flv, dsh, dtm, dtyp, frq, infodict) = state
+ #print 'regenerating %s' % dtyp
+
+ # restore the masked data first, then the dates array with its own dtype
+ MaskedArray.__setstate__(time_series, (ver, shp, typ, isf, raw, msk, flv))
+ _dates = time_series._dates
+ #_dates.__setstate__((ver, dsh, typ, isf, dtm, frq)) #use remote typ
+ ndarray.__setstate__(_dates,(dsh,dtyp, isf, dtm))
+ _dates.freq = frq
+ # reset the cached entries of the dates array to None
+ _dates._cachedinfo.update(dict(full=None, hasdups=None, steps=None,
+ toobj=None, toord=None, tostr=None))
+ # Update the _optinfo dictionary
+ time_series._optinfo.update(infodict)
+ return time_series
+
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
new file mode 100644
index 0000000000..3e33776af0
--- /dev/null
+++ b/python/pyspark/context.py
@@ -0,0 +1,258 @@
+import os
+import atexit
+import shutil
+import sys
+import tempfile
+from threading import Lock
+from tempfile import NamedTemporaryFile
+
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
+from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.rdd import RDD
+
+from py4j.java_collections import ListConverter
+
+
+class SparkContext(object):
+ """
+ Main entry point for Spark functionality. A SparkContext represents the
+ connection to a Spark cluster, and can be used to create L{RDD}s and
+ broadcast variables on that cluster.
+ """
+
+ gateway = launch_gateway()
+ jvm = gateway.jvm
+ _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+ _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
+ _takePartition = jvm.PythonRDD.takePartition
+ _next_accum_id = 0
+ _active_spark_context = None
+ _lock = Lock()
+
+ def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
+ environment=None, batchSize=1024):
+ """
+ Create a new SparkContext.
+
+ @param master: Cluster URL to connect to
+ (e.g. mesos://host:port, spark://host:port, local[4]).
+ @param jobName: A name for your job, to display on the cluster web UI
+ @param sparkHome: Location where Spark is installed on cluster nodes.
+ @param pyFiles: Collection of .zip or .py files to send to the cluster
+ and add to PYTHONPATH. These can be paths on the local file
+ system or HDFS, HTTP, HTTPS, or FTP URLs.
+ @param environment: A dictionary of environment variables to set on
+ worker nodes.
+ @param batchSize: The number of Python objects represented as a single
+ Java object. Set 1 to disable batching or -1 to use an
+ unlimited batch size.
+ """
+ with SparkContext._lock:
+ if SparkContext._active_spark_context:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = self
+ self.master = master
+ self.jobName = jobName
+ self.sparkHome = sparkHome or None # None becomes null in Py4J
+ self.environment = environment or {}
+ self.batchSize = batchSize # -1 represents a unlimited batch size
+
+ # Create the Java SparkContext through Py4J
+ empty_string_array = self.gateway.new_array(self.jvm.String, 0)
+ self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+ empty_string_array)
+
+ # Create a single Accumulator in Java that we'll send all our updates through;
+ # they will be passed back to us through a TCP server
+ self._accumulatorServer = accumulators._start_update_server()
+ (host, port) = self._accumulatorServer.server_address
+ self._javaAccumulator = self._jsc.accumulator(
+ self.jvm.java.util.ArrayList(),
+ self.jvm.PythonAccumulatorParam(host, port))
+
+ self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+ # Broadcast's __reduce__ method stores Broadcast instances here.
+ # This allows other code to determine which Broadcast instances have
+ # been pickled, so it can determine which Java broadcast objects to
+ # send.
+ self._pickled_broadcast_vars = set()
+
+ # Deploy any code dependencies specified in the constructor
+ for path in (pyFiles or []):
+ self.addPyFile(path)
+ SparkFiles._sc = self
+ sys.path.append(SparkFiles.getRootDirectory())
+
+ @property
+ def defaultParallelism(self):
+ """
+ Default level of parallelism to use when not given by user (e.g. for
+ reduce tasks)
+ """
+ return self._jsc.sc().defaultParallelism()
+
+ def __del__(self):
+ self.stop()
+
+ def stop(self):
+ """
+ Shut down the SparkContext.
+ """
+ if self._jsc:
+ self._jsc.stop()
+ self._jsc = None
+ if self._accumulatorServer:
+ self._accumulatorServer.shutdown()
+ self._accumulatorServer = None
+ with SparkContext._lock:
+ SparkContext._active_spark_context = None
+
+ def parallelize(self, c, numSlices=None):
+ """
+ Distribute a local Python collection to form an RDD.
+ """
+ numSlices = numSlices or self.defaultParallelism
+ # Calling the Java parallelize() method with an ArrayList is too slow,
+ # because it sends O(n) Py4J commands. As an alternative, serialized
+ # objects are written to a file and loaded through textFile().
+ tempFile = NamedTemporaryFile(delete=False)
+ atexit.register(lambda: os.unlink(tempFile.name))
+ if self.batchSize != 1:
+ c = batched(c, self.batchSize)
+ for x in c:
+ write_with_length(dump_pickle(x), tempFile)
+ tempFile.close()
+ jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self)
+
+ def textFile(self, name, minSplits=None):
+ """
+ Read a text file from HDFS, a local file system (available on all
+ nodes), or any Hadoop-supported file system URI, and return it as an
+ RDD of Strings.
+ """
+ minSplits = minSplits or min(self.defaultParallelism, 2)
+ jrdd = self._jsc.textFile(name, minSplits)
+ return RDD(jrdd, self)
+
+ def _checkpointFile(self, name):
+ jrdd = self._jsc.checkpointFile(name)
+ return RDD(jrdd, self)
+
+ def union(self, rdds):
+ """
+ Build the union of a list of RDDs.
+ """
+ first = rdds[0]._jrdd
+ rest = [x._jrdd for x in rdds[1:]]
+ rest = ListConverter().convert(rest, self.gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self)
+
+ def broadcast(self, value):
+ """
+ Broadcast a read-only variable to the cluster, returning a C{Broadcast}
+ object for reading it in distributed functions. The variable will be
+ sent to each cluster only once.
+ """
+ jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ return Broadcast(jbroadcast.id(), value, jbroadcast,
+ self._pickled_broadcast_vars)
+
+ def accumulator(self, value, accum_param=None):
+ """
+ Create an C{Accumulator} with the given initial value, using a given
+ AccumulatorParam helper object to define how to add values of the data
+ type if provided. Default AccumulatorParams are used for integers and
+ floating-point numbers if you do not provide one. For other types, the
+ AccumulatorParam must implement two methods:
+ - C{zero(value)}: provide a "zero value" for the type, compatible in
+ dimensions with the provided C{value} (e.g., a zero vector).
+ - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+ type, returning a new value; for efficiency, can also update C{val1}
+ in place and return it.
+ """
+ if accum_param == None:
+ if isinstance(value, int):
+ accum_param = accumulators.INT_ACCUMULATOR_PARAM
+ elif isinstance(value, float):
+ accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
+ elif isinstance(value, complex):
+ accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+ else:
+ raise Exception("No default accumulator param for type %s" % type(value))
+ SparkContext._next_accum_id += 1
+ return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
+ def addFile(self, path):
+ """
+ Add a file to be downloaded with this Spark job on every node.
+ The C{path} passed can be either a local file, a file in HDFS
+ (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
+ FTP URI.
+
+ To access the file in Spark jobs, use
+ L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
+ download location.
+
+ >>> from pyspark import SparkFiles
+ >>> path = os.path.join(tempdir, "test.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("100")
+ >>> sc.addFile(path)
+ >>> def func(iterator):
+ ... with open(SparkFiles.get("test.txt")) as testFile:
+ ... fileVal = int(testFile.readline())
+ ... return [x * 100 for x in iterator]
+ >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
+ [100, 200, 300, 400]
+ """
+ self._jsc.sc().addFile(path)
+
+ def clearFiles(self):
+ """
+ Clear the job's list of files added by L{addFile} or L{addPyFile} so
+ that they do not get downloaded to any new nodes.
+ """
+ # TODO: remove added .py or .zip files from the PYTHONPATH?
+ self._jsc.sc().clearFiles()
+
+ def addPyFile(self, path):
+ """
+ Add a .py or .zip dependency for all tasks to be executed on this
+ SparkContext in the future. The C{path} passed can be either a local
+ file, a file in HDFS (or other Hadoop-supported filesystems), or an
+ HTTP, HTTPS or FTP URI.
+ """
+ self.addFile(path)
+ filename = path.split("/")[-1]
+
+ def setCheckpointDir(self, dirName, useExisting=False):
+ """
+ Set the directory under which RDDs are going to be checkpointed. The
+ directory must be a HDFS path if running on a cluster.
+
+ If the directory does not exist, it will be created. If the directory
+ exists and C{useExisting} is set to true, then the exisiting directory
+ will be used. Otherwise an exception will be thrown to prevent
+ accidental overriding of checkpoint files in the existing directory.
+ """
+ self._jsc.sc().setCheckpointDir(dirName, useExisting)
+
+
+def _test():
+ import doctest
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['tempdir'] = tempfile.mkdtemp()
+ atexit.register(lambda: shutil.rmtree(globs['tempdir']))
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/files.py b/python/pyspark/files.py
new file mode 100644
index 0000000000..98f6a399cc
--- /dev/null
+++ b/python/pyspark/files.py
@@ -0,0 +1,38 @@
+import os
+
+
+class SparkFiles(object):
+ """
+ Resolves paths to files added through
+ L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.
+
+ SparkFiles contains only classmethods; users should not create SparkFiles
+ instances.
+ """
+
+ _root_directory = None
+ _is_running_on_worker = False
+ _sc = None
+
+ def __init__(self):
+ raise NotImplementedError("Do not construct SparkFiles objects")
+
+ @classmethod
+ def get(cls, filename):
+ """
+ Get the absolute path of a file added through C{SparkContext.addFile()}.
+ """
+ path = os.path.join(SparkFiles.getRootDirectory(), filename)
+ return os.path.abspath(path)
+
+ @classmethod
+ def getRootDirectory(cls):
+ """
+ Get the root directory that contains files added through
+ C{SparkContext.addFile()}.
+ """
+ if cls._is_running_on_worker:
+ return cls._root_directory
+ else:
+ # This will have to change if we support multiple SparkContexts:
+ return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
new file mode 100644
index 0000000000..2329e536cc
--- /dev/null
+++ b/python/pyspark/java_gateway.py
@@ -0,0 +1,38 @@
+import os
+import sys
+from subprocess import Popen, PIPE
+from threading import Thread
+from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+
+
+SPARK_HOME = os.environ["SPARK_HOME"]
+
+
+def launch_gateway():
+ # Launch the Py4j gateway using Spark's run command so that we pick up the
+ # proper classpath and SPARK_MEM settings from spark-env.sh
+ command = [os.path.join(SPARK_HOME, "run"), "py4j.GatewayServer",
+ "--die-on-broken-pipe", "0"]
+ proc = Popen(command, stdout=PIPE, stdin=PIPE)
+ # Determine which ephemeral port the server started on:
+ port = int(proc.stdout.readline())
+ # Create a thread to echo output from the GatewayServer, which is required
+ # for Java log output to show up:
+ class EchoOutputThread(Thread):
+ def __init__(self, stream):
+ Thread.__init__(self)
+ self.daemon = True
+ self.stream = stream
+
+ def run(self):
+ while True:
+ line = self.stream.readline()
+ sys.stderr.write(line)
+ EchoOutputThread(proc.stdout).start()
+ # Connect to the gateway
+ gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
+ # Import the classes used by PySpark
+ java_import(gateway.jvm, "spark.api.java.*")
+ java_import(gateway.jvm, "spark.api.python.*")
+ java_import(gateway.jvm, "scala.Tuple2")
+ return gateway
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
new file mode 100644
index 0000000000..7036c47980
--- /dev/null
+++ b/python/pyspark/join.py
@@ -0,0 +1,92 @@
+"""
+Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+
+ * Neither the name of the Douban Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+
+def _do_python_join(rdd, other, numSplits, dispatch):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch)
+
+
+def python_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_right_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not vbuf:
+ vbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_left_outer_join(rdd, other, numSplits):
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ if not wbuf:
+ wbuf.append(None)
+ return [(v, w) for v in vbuf for w in wbuf]
+ return _do_python_join(rdd, other, numSplits, dispatch)
+
+
+def python_cogroup(rdd, other, numSplits):
+ vs = rdd.map(lambda (k, v): (k, (1, v)))
+ ws = other.map(lambda (k, v): (k, (2, v)))
+ def dispatch(seq):
+ vbuf, wbuf = [], []
+ for (n, v) in seq:
+ if n == 1:
+ vbuf.append(v)
+ elif n == 2:
+ wbuf.append(v)
+ return (vbuf, wbuf)
+ return vs.union(ws).groupByKey(numSplits).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
new file mode 100644
index 0000000000..d53355a8f1
--- /dev/null
+++ b/python/pyspark/rdd.py
@@ -0,0 +1,761 @@
+import atexit
+from base64 import standard_b64encode as b64enc
+import copy
+from collections import defaultdict
+from itertools import chain, ifilter, imap, product
+import operator
+import os
+import shlex
+from subprocess import Popen, PIPE
+from tempfile import NamedTemporaryFile
+from threading import Thread
+
+from pyspark import cloudpickle
+from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
+ read_from_pickle_file
+from pyspark.join import python_join, python_left_outer_join, \
+ python_right_outer_join, python_cogroup
+
+from py4j.java_collections import ListConverter, MapConverter
+
+
+__all__ = ["RDD"]
+
+
+class RDD(object):
+ """
+ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+ Represents an immutable, partitioned collection of elements that can be
+ operated on in parallel.
+ """
+
+ def __init__(self, jrdd, ctx):
+ self._jrdd = jrdd
+ self.is_cached = False
+ self.is_checkpointed = False
+ self.ctx = ctx
+ self._partitionFunc = None
+
+ @property
+ def context(self):
+ """
+ The L{SparkContext} that this RDD was created on.
+ """
+ return self.ctx
+
+ def cache(self):
+ """
+ Persist this RDD with the default storage level (C{MEMORY_ONLY}).
+ """
+ self.is_cached = True
+ self._jrdd.cache()
+ return self
+
+ def checkpoint(self):
+ """
+ Mark this RDD for checkpointing. It will be saved to a file inside the
+ checkpoint directory set with L{SparkContext.setCheckpointDir()} and
+ all references to its parent RDDs will be removed. This function must
+ be called before any job has been executed on this RDD. It is strongly
+ recommended that this RDD is persisted in memory, otherwise saving it
+ on a file will require recomputation.
+ """
+ self.is_checkpointed = True
+ self._jrdd.rdd().checkpoint()
+
+ def isCheckpointed(self):
+ """
+ Return whether this RDD has been checkpointed or not
+ """
+ return self._jrdd.rdd().isCheckpointed()
+
+ def getCheckpointFile(self):
+ """
+ Gets the name of the file to which this RDD was checkpointed
+ """
+ checkpointFile = self._jrdd.rdd().getCheckpointFile()
+ if checkpointFile.isDefined():
+ return checkpointFile.get()
+ else:
+ return None
+
+ # TODO persist(self, storageLevel)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD containing the distinct elements in this RDD.
+ """
+ def func(split, iterator): return imap(f, iterator)
+ return PipelinedRDD(self, func, preservesPartitioning)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by first applying a function to all elements of this
+ RDD, and then flattening the results.
+
+ >>> rdd = sc.parallelize([2, 3, 4])
+ >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
+ [1, 1, 1, 2, 2, 3]
+ >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
+ [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
+ """
+ def func(s, iterator): return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithSplit(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> def f(iterator): yield sum(iterator)
+ >>> rdd.mapPartitions(f).collect()
+ [3, 7]
+ """
+ def func(s, iterator): return f(iterator)
+ return self.mapPartitionsWithSplit(func)
+
+ def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition of this RDD,
+ while tracking the index of the original partition.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(splitIndex, iterator): yield splitIndex
+ >>> rdd.mapPartitionsWithSplit(f).sum()
+ 6
+ """
+ return PipelinedRDD(self, f, preservesPartitioning)
+
+ def filter(self, f):
+ """
+ Return a new RDD containing only the elements that satisfy a predicate.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
+ >>> rdd.filter(lambda x: x % 2 == 0).collect()
+ [2, 4]
+ """
+ def func(iterator): return ifilter(f, iterator)
+ return self.mapPartitions(func)
+
+ def distinct(self):
+ """
+ Return a new RDD containing the distinct elements in this RDD.
+
+ >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
+ [1, 2, 3]
+ """
+ return self.map(lambda x: (x, "")) \
+ .reduceByKey(lambda x, _: x) \
+ .map(lambda (x, _): x)
+
+ # TODO: sampling needs to be re-implemented due to Batch
+ #def sample(self, withReplacement, fraction, seed):
+ # jrdd = self._jrdd.sample(withReplacement, fraction, seed)
+ # return RDD(jrdd, self.ctx)
+
+ #def takeSample(self, withReplacement, num, seed):
+ # vals = self._jrdd.takeSample(withReplacement, num, seed)
+ # return [load_pickle(bytes(x)) for x in vals]
+
+ def union(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> rdd.union(rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ return RDD(self._jrdd.union(other._jrdd), self.ctx)
+
+ def __add__(self, other):
+ """
+ Return the union of this RDD and another one.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3])
+ >>> (rdd + rdd).collect()
+ [1, 1, 2, 3, 1, 1, 2, 3]
+ """
+ if not isinstance(other, RDD):
+ raise TypeError
+ return self.union(other)
+
+ # TODO: sort
+
+ def glom(self):
+ """
+ Return an RDD created by coalescing all elements within each partition
+ into a list.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> sorted(rdd.glom().collect())
+ [[1, 2], [3, 4]]
+ """
+ def func(iterator): yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cartesian(self, other):
+ """
+ Return the Cartesian product of this RDD and another one, that is, the
+ RDD of all pairs of elements C{(a, b)} where C{a} is in C{self} and
+ C{b} is in C{other}.
+
+ >>> rdd = sc.parallelize([1, 2])
+ >>> sorted(rdd.cartesian(rdd).collect())
+ [(1, 1), (1, 2), (2, 1), (2, 2)]
+ """
+ # Due to batching, we can't use the Java cartesian method.
+ java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+ def unpack_batches(pair):
+ (x, y) = pair
+ if type(x) == Batch or type(y) == Batch:
+ xs = x.items if type(x) == Batch else [x]
+ ys = y.items if type(y) == Batch else [y]
+ for pair in product(xs, ys):
+ yield pair
+ else:
+ yield pair
+ return java_cartesian.flatMap(unpack_batches)
+
+ def groupBy(self, f, numSplits=None):
+ """
+ Return an RDD of grouped items.
+
+ >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
+ >>> result = rdd.groupBy(lambda x: x % 2).collect()
+ >>> sorted([(x, sorted(y)) for (x, y) in result])
+ [(0, [2, 8]), (1, [1, 1, 3, 5])]
+ """
+ return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
+
+ def pipe(self, command, env={}):
+ """
+ Return an RDD created by piping elements to a forked external process.
+
+ >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+ ['1', '2', '3']
+ """
+ def func(iterator):
+ pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+ def pipe_objs(out):
+ for obj in iterator:
+ out.write(str(obj).rstrip('\n') + '\n')
+ out.close()
+ Thread(target=pipe_objs, args=[pipe.stdin]).start()
+ return (x.rstrip('\n') for x in pipe.stdout)
+ return self.mapPartitions(func)
+
+ def foreach(self, f):
+ """
+ Applies a function to all elements of this RDD.
+
+ >>> def f(x): print x
+ >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
+ """
+ self.map(f).collect() # Force evaluation
+
+ def collect(self):
+ """
+ Return a list that contains all of the elements in this RDD.
+ """
+ picklesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(picklesInJava))
+
+ def _collect_iterator_through_file(self, iterator):
+     """
+     Yield the deserialized items of a Java-side iterator of pickled objects.
+
+     The iterator is written to a temporary pickle file by the context's
+     C{_writeIteratorToPickleFile} helper and then read back in Python.
+     The temporary file is removed once the items have been read, when the
+     generator is closed early, or when reading fails.
+     """
+     # Transferring lots of data through Py4J can be slow because
+     # socket.readline() is inefficient. Instead, we'll dump the data to a
+     # file and read it back.
+     tempFile = NamedTemporaryFile(delete=False)
+     tempFile.close()
+     path = tempFile.name
+     def clean_up_file():
+         try:
+             os.unlink(path)
+         except OSError:
+             # The file is already gone; nothing left to clean up.
+             pass
+     # Fallback in case the interpreter exits while this generator is still
+     # suspended and its finally clause has not run.
+     atexit.register(clean_up_file)
+     try:
+         self.ctx._writeIteratorToPickleFile(iterator, path)
+         # Read the data into Python and deserialize it:
+         with open(path, 'rb') as pickleFile:
+             for item in read_from_pickle_file(pickleFile):
+                 yield item
+     finally:
+         clean_up_file()
+
+ def reduce(self, f):
+     """
+     Reduces the elements of this RDD using the specified associative binary
+     operator.
+
+     Each partition is reduced on its own and the per-partition results are
+     then combined with the same operator. Within a partition C{f} is always
+     called as C{f(accumulated, element)}, matching the argument order used
+     when the partial results are combined. Partitions with no elements
+     contribute nothing.
+
+     >>> from operator import add
+     >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
+     15
+     >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
+     10
+     """
+     def func(iterator):
+         iterator = iter(iterator)
+         try:
+             # Seed the accumulator with the first element instead of using
+             # None as a sentinel, so that None elements (or an f that
+             # returns None) do not restart the accumulation.
+             initial = next(iterator)
+         except StopIteration:
+             # Empty partition: yield nothing.
+             return
+         yield reduce(f, iterator, initial)
+     vals = self.mapPartitions(func).collect()
+     return reduce(f, vals)
+
+ def fold(self, zeroValue, op):
+     """
+     Aggregate the elements of each partition, and then the results for all
+     the partitions, using a given associative function and a neutral "zero
+     value."
+
+     The function C{op(t1, t2)} is allowed to modify C{t1} and return it
+     as its result value to avoid object allocation; however, it should not
+     modify C{t2}.
+
+     The running accumulator is always passed as C{t1} and the next element
+     (or the next partition's result) as C{t2}.
+
+     >>> from operator import add
+     >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
+     15
+     """
+     def func(iterator):
+         acc = zeroValue
+         for obj in iterator:
+             # Accumulator first: op may mutate its first argument, and must
+             # never be handed an RDD element in that position.
+             acc = op(acc, obj)
+         yield acc
+     vals = self.mapPartitions(func).collect()
+     return reduce(op, vals, zeroValue)
+
+ # TODO: aggregate
+
+ def sum(self):
+ """
+ Add up the elements in this RDD.
+
+ >>> sc.parallelize([1.0, 2.0, 3.0]).sum()
+ 6.0
+ """
+ return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
+ def count(self):
+ """
+ Return the number of elements in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).count()
+ 3
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
+
+ def countByValue(self):
+ """
+ Return the count of each unique value in this RDD as a dictionary of
+ (value, count) pairs.
+
+ >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
+ [(1, 2), (2, 3)]
+ """
+ def countPartition(iterator):
+ counts = defaultdict(int)
+ for obj in iterator:
+ counts[obj] += 1
+ yield counts
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] += v
+ return m1
+ return self.mapPartitions(countPartition).reduce(mergeMaps)
+
+ def take(self, num):
+ """
+ Take the first num elements of the RDD.
+
+ This currently scans the partitions *one by one*, so it will be slow if
+ a lot of partitions are required. In that case, use L{collect} to get
+ the whole RDD instead.
+
+ >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
+ [2, 3]
+ >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+ [2, 3, 4, 5, 6]
+ """
+ items = []
+ for partition in range(self._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
+ items.extend(self._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
+ return items[:num]
+
+ def first(self):
+ """
+ Return the first element in this RDD.
+
+ >>> sc.parallelize([2, 3, 4]).first()
+ 2
+ """
+ return self.take(1)[0]
+
+ def saveAsTextFile(self, path):
+ """
+ Save this RDD as a text file, using string representations of elements.
+
+ >>> tempFile = NamedTemporaryFile(delete=True)
+ >>> tempFile.close()
+ >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
+ >>> from fileinput import input
+ >>> from glob import glob
+ >>> ''.join(input(glob(tempFile.name + "/part-0000*")))
+ '0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
+ """
+ def func(split, iterator):
+ return (str(x).encode("utf-8") for x in iterator)
+ keyed = PipelinedRDD(self, func)
+ keyed._bypass_serializer = True
+ keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+
+ # Pair functions
+
+ def collectAsMap(self):
+ """
+ Return the key-value pairs in this RDD to the master as a dictionary.
+
+ >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
+ >>> m[1]
+ 2
+ >>> m[3]
+ 4
+ """
+ return dict(self.collect())
+
+ def reduceByKey(self, func, numSplits=None):
+ """
+ Merge the values for each key using an associative reduce function.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ Output will be hash-partitioned with C{numSplits} splits, or the
+ default parallelism level if C{numSplits} is not specified.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKey(add).collect())
+ [('a', 2), ('b', 1)]
+ """
+ return self.combineByKey(lambda x: x, func, func, numSplits)
+
+ def reduceByKeyLocally(self, func):
+ """
+ Merge the values for each key using an associative reduce function, but
+ return the results immediately to the master as a dictionary.
+
+ This will also perform the merging locally on each mapper before
+ sending results to a reducer, similarly to a "combiner" in MapReduce.
+
+ >>> from operator import add
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.reduceByKeyLocally(add).items())
+ [('a', 2), ('b', 1)]
+ """
+ def reducePartition(iterator):
+ m = {}
+ for (k, v) in iterator:
+ m[k] = v if k not in m else func(m[k], v)
+ yield m
+ def mergeMaps(m1, m2):
+ for (k, v) in m2.iteritems():
+ m1[k] = v if k not in m1 else func(m1[k], v)
+ return m1
+ return self.mapPartitions(reducePartition).reduce(mergeMaps)
+
+ def countByKey(self):
+ """
+ Count the number of elements for each key, and return the result to the
+ master as a dictionary.
+
+ >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(rdd.countByKey().items())
+ [('a', 2), ('b', 1)]
+ """
+ return self.map(lambda x: x[0]).countByValue()
+
+ def join(self, other, numSplits=None):
+ """
+ Return an RDD containing all pairs of elements with matching keys in
+ C{self} and C{other}.
+
+ Each pair of elements will be returned as a (k, (v1, v2)) tuple, where
+ (k, v1) is in C{self} and (k, v2) is in C{other}.
+
+ Performs a hash join across the cluster.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2), ("a", 3)])
+ >>> sorted(x.join(y).collect())
+ [('a', (1, 2)), ('a', (1, 3))]
+ """
+ return python_join(self, other, numSplits)
+
+ def leftOuterJoin(self, other, numSplits=None):
+ """
+ Perform a left outer join of C{self} and C{other}.
+
+ For each element (k, v) in C{self}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for w in C{other}, or the pair
+ (k, (v, None)) if no elements in other have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.leftOuterJoin(y).collect())
+ [('a', (1, 2)), ('b', (4, None))]
+ """
+ return python_left_outer_join(self, other, numSplits)
+
+ def rightOuterJoin(self, other, numSplits=None):
+ """
+ Perform a right outer join of C{self} and C{other}.
+
+ For each element (k, w) in C{other}, the resulting RDD will either
+ contain all pairs (k, (v, w)) for v in this, or the pair (k, (None, w))
+ if no elements in C{self} have key k.
+
+ Hash-partitions the resulting RDD into the given number of partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(y.rightOuterJoin(x).collect())
+ [('a', (2, 1)), ('b', (None, 4))]
+ """
+ return python_right_outer_join(self, other, numSplits)
+
+ # TODO: add option to control map-side combining
+ def partitionBy(self, numSplits, partitionFunc=hash):
+ """
+ Return a copy of the RDD partitioned using the specified partitioner.
+
+ >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
+ >>> sets = pairs.partitionBy(2).glom().collect()
+ >>> set(sets[0]).intersection(set(sets[1]))
+ set([])
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ # Transferring O(n) objects to Java is too expensive. Instead, we'll
+ # form the hash buckets in Python, transferring O(numSplits) objects
+ # to Java. Each object is a (splitNumber, [objects]) pair.
+ def add_shuffle_key(split, iterator):
+ buckets = defaultdict(list)
+ for (k, v) in iterator:
+ buckets[partitionFunc(k) % numSplits].append((k, v))
+ for (split, items) in buckets.iteritems():
+ yield str(split)
+ yield dump_pickle(Batch(items))
+ keyed = PipelinedRDD(self, add_shuffle_key)
+ keyed._bypass_serializer = True
+ pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+ id(partitionFunc))
+ jrdd = pairRDD.partitionBy(partitioner).values()
+ rdd = RDD(jrdd, self.ctx)
+ # This is required so that id(partitionFunc) remains unique, even if
+ # partitionFunc is a lambda:
+ rdd._partitionFunc = partitionFunc
+ return rdd
+
+ # TODO: add control over map-side aggregation
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numSplits=None):
+ """
+ Generic function to combine the elements for each key using a custom
+ set of aggregation functions.
+
+ Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined
+ type" C. Note that V and C can be different -- for example, one might
+ group an RDD of type (Int, Int) into an RDD of type (Int, List[Int]).
+
+ Users provide three functions:
+
+ - C{createCombiner}, which turns a V into a C (e.g., creates
+ a one-element list)
+ - C{mergeValue}, to merge a V into a C (e.g., adds it to the end of
+ a list)
+ - C{mergeCombiners}, to combine two C's into a single one.
+
+ In addition, users can control the partitioning of the output RDD.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> def f(x): return x
+ >>> def add(a, b): return a + str(b)
+ >>> sorted(x.combineByKey(str, add, add).collect())
+ [('a', '11'), ('b', '1')]
+ """
+ if numSplits is None:
+ numSplits = self.ctx.defaultParallelism
+ def combineLocally(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if k not in combiners:
+ combiners[k] = createCombiner(v)
+ else:
+ combiners[k] = mergeValue(combiners[k], v)
+ return combiners.iteritems()
+ locally_combined = self.mapPartitions(combineLocally)
+ shuffled = locally_combined.partitionBy(numSplits)
+ def _mergeCombiners(iterator):
+ combiners = {}
+ for (k, v) in iterator:
+ if not k in combiners:
+ combiners[k] = v
+ else:
+ combiners[k] = mergeCombiners(combiners[k], v)
+ return combiners.iteritems()
+ return shuffled.mapPartitions(_mergeCombiners)
+
+ # TODO: support variant with custom partitioner
+ def groupByKey(self, numSplits=None):
+ """
+ Group the values for each key in the RDD into a single sequence.
+ Hash-partitions the resulting RDD into numSplits partitions.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
+ >>> sorted(x.groupByKey().collect())
+ [('a', [1, 1]), ('b', [1])]
+ """
+
+ def createCombiner(x):
+ return [x]
+
+ def mergeValue(xs, x):
+ xs.append(x)
+ return xs
+
+ def mergeCombiners(a, b):
+ return a + b
+
+ return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
+ numSplits)
+
+ # TODO: add tests
+ def flatMapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a flatMap function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def mapValues(self, f):
+ """
+ Pass each value in the key-value pair RDD through a map function
+ without changing the keys; this also retains the original RDD's
+ partitioning.
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ # TODO: support varargs cogroup of several RDDs.
+ def groupWith(self, other):
+ """
+ Alias for cogroup.
+ """
+ return self.cogroup(other)
+
+ # TODO: add variant with custom partitioner
+ def cogroup(self, other, numSplits=None):
+ """
+ For each key k in C{self} or C{other}, return a resulting RDD that
+ contains a tuple with the list of values for that key in C{self} as well
+ as C{other}.
+
+ >>> x = sc.parallelize([("a", 1), ("b", 4)])
+ >>> y = sc.parallelize([("a", 2)])
+ >>> sorted(x.cogroup(y).collect())
+ [('a', ([1], [2])), ('b', ([4], []))]
+ """
+ return python_cogroup(self, other, numSplits)
+
+ # TODO: `lookup` is disabled because we can't make direct comparisons based
+ # on the key; we need to compare the hash of the key to the hash of the
+ # keys in the pairs. This could be an expensive operation, since those
+ # hashes aren't retained.
+
+
+class PipelinedRDD(RDD):
+ """
+ Pipelined maps:
+ >>> rdd = sc.parallelize([1, 2, 3, 4])
+ >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+ >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
+ [4, 8, 12, 16]
+
+ Pipelined reduces:
+ >>> from operator import add
+ >>> rdd.map(lambda x: 2 * x).reduce(add)
+ 20
+ >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
+ 20
+ """
+ def __init__(self, prev, func, preservesPartitioning=False):
+ if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ prev_func = prev.func
+ def pipeline_func(split, iterator):
+ return func(split, prev_func(split, iterator))
+ self.func = pipeline_func
+ self.preservesPartitioning = \
+ prev.preservesPartitioning and preservesPartitioning
+ self._prev_jrdd = prev._prev_jrdd
+ else:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self.is_cached = False
+ self.is_checkpointed = False
+ self.ctx = prev.ctx
+ self.prev = prev
+ self._jrdd_val = None
+ self._bypass_serializer = False
+
+ @property
+ def _jrdd(self):
+ if self._jrdd_val:
+ return self._jrdd_val
+ func = self.func
+ if not self._bypass_serializer and self.ctx.batchSize != 1:
+ oldfunc = self.func
+ batchSize = self.ctx.batchSize
+ def batched_func(split, iterator):
+ return batched(oldfunc(split, iterator), batchSize)
+ func = batched_func
+ cmds = [func, self._bypass_serializer]
+ pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+ self.ctx.gateway._gateway_client)
+ self.ctx._pickled_broadcast_vars.clear()
+ class_manifest = self._prev_jrdd.classManifest()
+ env = copy.copy(self.ctx.environment)
+ env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+ env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
+ python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+ pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+ self._jrdd_val = python_rdd.asJavaRDD()
+ return self._jrdd_val
+
+ def _is_pipelinable(self):
+ return not (self.is_cached or self.is_checkpointed)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ doctest.testmod(globs=globs)
+ globs['sc'].stop()
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
new file mode 100644
index 0000000000..115cf28cc2
--- /dev/null
+++ b/python/pyspark/serializers.py
@@ -0,0 +1,83 @@
+import struct
+import cPickle
+
+
+class Batch(object):
+ """
+ Used to store multiple RDD entries as a single Java object.
+
+ This relieves us from having to explicitly track whether an RDD
+ is stored as batches of objects and avoids problems when processing
+ the union() of batched and unbatched RDDs (e.g. the union() of textFile()
+ with another RDD).
+ """
+ def __init__(self, items):
+ self.items = items
+
+
+def batched(iterator, batchSize):
+    """
+    Group the items of C{iterator} into L{Batch} objects.
+
+    A C{batchSize} of -1 means "unlimited": every item goes into a single
+    Batch. Otherwise a Batch is emitted each time C{batchSize} items have
+    been gathered, followed by one final, smaller Batch for any leftovers.
+    """
+    if batchSize == -1:
+        # Unlimited batch size: everything in one Batch.
+        yield Batch(list(iterator))
+        return
+    pending = []
+    for element in iterator:
+        pending.append(element)
+        if len(pending) == batchSize:
+            yield Batch(pending)
+            pending = []
+    if pending:
+        yield Batch(pending)
+
+
+def dump_pickle(obj):
+ return cPickle.dumps(obj, 2)
+
+
+load_pickle = cPickle.loads
+
+
+def read_long(stream):
+ length = stream.read(8)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!q", length)[0]
+
+
+def read_int(stream):
+ length = stream.read(4)
+ if length == "":
+ raise EOFError
+ return struct.unpack("!i", length)[0]
+
+
+def write_int(value, stream):
+ stream.write(struct.pack("!i", value))
+
+
+def write_with_length(obj, stream):
+ write_int(len(obj), stream)
+ stream.write(obj)
+
+
+def read_with_length(stream):
+    """
+    Read one length-prefixed object from C{stream} and return its raw bytes.
+
+    The length is read with L{read_int}; that many bytes are then read as
+    the payload. Raises EOFError if the stream ends before the whole payload
+    has been read.
+    """
+    length = read_int(stream)
+    obj = stream.read(length)
+    # Compare against the declared length rather than the empty string, so
+    # that a truncated payload is reported and a zero-length payload is not
+    # mistaken for end of stream.
+    if len(obj) < length:
+        raise EOFError
+    return obj
+
+
+def read_from_pickle_file(stream):
+ try:
+ while True:
+ obj = load_pickle(read_with_length(stream))
+ if type(obj) == Batch: # We don't care about inheritance
+ for item in obj.items:
+ yield item
+ else:
+ yield obj
+ except EOFError:
+ return
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
new file mode 100644
index 0000000000..f6328c561f
--- /dev/null
+++ b/python/pyspark/shell.py
@@ -0,0 +1,17 @@
+"""
+An interactive shell.
+
+This file is designed to be launched as a PYTHONSTARTUP script.
+"""
+import os
+from pyspark.context import SparkContext
+
+
+# Create the context that the interactive session exposes as `sc`; the master
+# URL comes from the MASTER environment variable and defaults to "local".
+sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell")
+print "Spark context available as sc."
+
+# The ./pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
+# which allows us to execute the user's PYTHONSTARTUP file:
+_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
+if _pythonstartup and os.path.isfile(_pythonstartup):
+    execfile(_pythonstartup)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..46ab34f063
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,112 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import shutil
+import sys
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import SPARK_HOME
+
+
+class PySparkTestCase(unittest.TestCase):
+
+ def setUp(self):
+ self._old_sys_path = list(sys.path)
+ class_name = self.__class__.__name__
+ self.sc = SparkContext('local[4]', class_name , batchSize=2)
+
+ def tearDown(self):
+ self.sc.stop()
+ sys.path = self._old_sys_path
+ # To avoid Akka rebinding to the same port, since it doesn't unbind
+ # immediately on shutdown
+ self.sc.jvm.System.clearProperty("spark.master.port")
+
+
+class TestCheckpoint(PySparkTestCase):
+    """
+    Checkpointing tests; each test gets its own temporary checkpoint path.
+    """
+
+    def setUp(self):
+        PySparkTestCase.setUp(self)
+        # NamedTemporaryFile is used only to reserve a unique path. The
+        # placeholder is closed before it is removed so that no file handle
+        # is leaked and the unlink also works where open files cannot be
+        # deleted; the bare path is then given to Spark as the checkpoint
+        # directory.
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        self.checkpointDir.close()
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
+
+    def tearDown(self):
+        PySparkTestCase.tearDown(self)
+        shutil.rmtree(self.checkpointDir.name)
+
+    def test_basic_checkpointing(self):
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        result = flatMappedRDD.collect()
+        time.sleep(1)  # 1 second
+        self.assertTrue(flatMappedRDD.isCheckpointed())
+        self.assertEqual(flatMappedRDD.collect(), result)
+        self.assertEqual(self.checkpointDir.name,
+                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+    def test_checkpoint_and_restore(self):
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: [x])
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        flatMappedRDD.count()  # forces a checkpoint to be computed
+        time.sleep(1)  # 1 second
+
+        self.assertIsNotNone(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        # assertEqual rather than the deprecated assertEquals alias, matching
+        # the assertions in test_basic_checkpointing.
+        self.assertEqual([1, 2, 3, 4], recovered.collect())
+
+
+class TestAddFile(PySparkTestCase):
+
+ def test_add_py_file(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this job fails due to `userlibrary` not being on the Python path:
+ def func(x):
+ from userlibrary import UserClass
+ return UserClass().hello()
+ self.assertRaises(Exception,
+ self.sc.parallelize(range(2)).map(func).first)
+ # Add the file, so the job should now succeed:
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addPyFile(path)
+ res = self.sc.parallelize(range(2)).map(func).first()
+ self.assertEqual("Hello World!", res)
+
+ def test_add_file_locally(self):
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ self.sc.addFile(path)
+ download_path = SparkFiles.get("hello.txt")
+ self.assertNotEqual(path, download_path)
+ with open(download_path) as test_file:
+ self.assertEquals("Hello World!\n", test_file.readline())
+
+ def test_add_py_file_locally(self):
+ # To ensure that we're actually testing addPyFile's effects, check that
+ # this fails due to `userlibrary` not being on the Python path:
+ def func():
+ from userlibrary import UserClass
+ self.assertRaises(ImportError, func)
+ path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+ self.sc.addFile(path)
+ from userlibrary import UserClass
+ self.assertEqual("Hello World!", UserClass().hello())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
new file mode 100644
index 0000000000..d33d6dd15f
--- /dev/null
+++ b/python/pyspark/worker.py
@@ -0,0 +1,52 @@
+"""
+Worker that receives input from Piped RDD.
+"""
+import sys
+from base64 import standard_b64decode
+# CloudPickler needs to be imported so that depicklers are registered using the
+# copy_reg module.
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.cloudpickle import CloudPickler
+from pyspark.files import SparkFiles
+from pyspark.serializers import write_with_length, read_with_length, write_int, \
+ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+
+
+# Redirect stdout to stderr so that users must return values from functions.
+old_stdout = sys.stdout
+sys.stdout = sys.stderr
+
+
+def load_obj():
+ return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
+
+
+def main():
+ split_index = read_int(sys.stdin)
+ spark_files_dir = load_pickle(read_with_length(sys.stdin))
+ SparkFiles._root_directory = spark_files_dir
+ SparkFiles._is_running_on_worker = True
+ sys.path.append(spark_files_dir)
+ num_broadcast_variables = read_int(sys.stdin)
+ for _ in range(num_broadcast_variables):
+ bid = read_long(sys.stdin)
+ value = read_with_length(sys.stdin)
+ _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ func = load_obj()
+ bypassSerializer = load_obj()
+ if bypassSerializer:
+ dumps = lambda x: x
+ else:
+ dumps = dump_pickle
+ iterator = read_from_pickle_file(sys.stdin)
+ for obj in func(split_index, iterator):
+ write_with_length(dumps(obj), old_stdout)
+ # Mark the beginning of the accumulators section of the output
+ write_int(-1, old_stdout)
+ for aid, accum in _accumulatorRegistry.items():
+ write_with_length(dump_pickle((aid, accum._value)), old_stdout)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/python/run-tests b/python/run-tests
new file mode 100755
index 0000000000..a3a9ff5dcb
--- /dev/null
+++ b/python/run-tests
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+# Runs the PySpark doctests and unit tests and reports an overall result.
+# Exits with status 1 if any test command failed, 0 otherwise.
+
+# Figure out where the Scala framework is installed
+FWDIR="$(cd "$(dirname "$0")"; cd ../; pwd)"
+
+# The test targets below are given relative to the python/ directory, so run
+# from there no matter where this script was invoked from.
+cd "$FWDIR/python" || exit 1
+
+# Becomes 1 as soon as any test command exits with a non-zero status.
+FAILED=0
+
+"$FWDIR/pyspark" pyspark/rdd.py
+FAILED=$(($?||$FAILED))
+
+"$FWDIR/pyspark" pyspark/context.py
+FAILED=$(($?||$FAILED))
+
+"$FWDIR/pyspark" -m doctest pyspark/broadcast.py
+FAILED=$(($?||$FAILED))
+
+"$FWDIR/pyspark" -m doctest pyspark/accumulators.py
+FAILED=$(($?||$FAILED))
+
+"$FWDIR/pyspark" -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
+if [[ $FAILED != 0 ]]; then
+    echo -en "\033[31m" # Red
+    echo "Had test failures; see logs."
+    echo -en "\033[0m" # No color
+    # Exit statuses are limited to 0-255, so use a plain 1 instead of -1.
+    exit 1
+else
+    echo -en "\033[32m" # Green
+    echo "Tests passed."
+    echo -en "\033[0m" # No color
+fi
+
+# TODO: in the long-run, it would be nice to use a test runner like `nose`.
+# The doctest fixtures are the current barrier to doing this.
diff --git a/python/test_support/hello.txt b/python/test_support/hello.txt
new file mode 100755
index 0000000000..980a0d5f19
--- /dev/null
+++ b/python/test_support/hello.txt
@@ -0,0 +1 @@
+Hello World!
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
new file mode 100755
index 0000000000..5bb6f5009f
--- /dev/null
+++ b/python/test_support/userlibrary.py
@@ -0,0 +1,7 @@
+"""
+Used to test shipping of code dependencies with SparkContext.addPyFile().
+"""
+
+class UserClass(object):
+ def hello(self):
+ return "Hello World!"
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index 72a946f3d7..da91c0f3ab 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -15,7 +15,8 @@
<url>http://spark-project.org/</url>
<properties>
- <deb.install.path>/usr/share/spark</deb.install.path>
+ <deb.pkg.name>spark-${classifier}</deb.pkg.name>
+ <deb.install.path>/usr/share/spark-${classifier}</deb.install.path>
<deb.user>root</deb.user>
</properties>
@@ -69,6 +70,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<properties>
<classifier>hadoop1</classifier>
</properties>
@@ -109,6 +115,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<properties>
<classifier>hadoop2</classifier>
</properties>
@@ -183,7 +195,7 @@
<goal>jdeb</goal>
</goals>
<configuration>
- <deb>${project.build.directory}/${project.artifactId}-${classifier}_${project.version}-${buildNumber}_all.deb</deb>
+ <deb>${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb</deb>
<attach>false</attach>
<compression>gzip</compression>
<dataSet>
diff --git a/repl-bin/src/deb/control/control b/repl-bin/src/deb/control/control
index afadb3fbfe..a6b4471d48 100644
--- a/repl-bin/src/deb/control/control
+++ b/repl-bin/src/deb/control/control
@@ -1,8 +1,8 @@
-Package: [[artifactId]]
+Package: [[deb.pkg.name]]
Version: [[version]]-[[buildNumber]]
Section: misc
Priority: extra
Architecture: all
Maintainer: Matei Zaharia <matei.zaharia@gmail.com>
-Description: spark repl
+Description: [[name]]
Distribution: development
diff --git a/repl/pom.xml b/repl/pom.xml
index 114e3e9932..2dc96beaf5 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -72,6 +72,11 @@
<profiles>
<profile>
<id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
<properties>
<classifier>hadoop1</classifier>
</properties>
@@ -97,6 +102,13 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -116,6 +128,12 @@
</profile>
<profile>
<id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
<properties>
<classifier>hadoop2</classifier>
</properties>
@@ -141,6 +159,13 @@
<scope>runtime</scope>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -150,6 +175,16 @@
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ <scope>provided</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
index 4c99e450bc..cfb1a390e6 100644
--- a/repl/src/test/resources/log4j.properties
+++ b/repl/src/test/resources/log4j.properties
@@ -1,8 +1,8 @@
-# Set everything to be logged to the console
+# Set everything to be logged to the repl/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
-log4j.appender.file.file=spark-tests.log
+log4j.appender.file.file=repl/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
diff --git a/run b/run
index 1528f83534..a094629449 100755
--- a/run
+++ b/run
@@ -63,6 +63,15 @@ CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
+STREAMING_DIR="$FWDIR/streaming"
+PYSPARK_DIR="$FWDIR/python"
+
+# Exit if the user hasn't compiled Spark
+if [ ! -e "$REPL_DIR/target" ]; then
+ echo "Failed to find Spark classes in $REPL_DIR/target" >&2
+ echo "You need to compile Spark before running this program" >&2
+ exit 1
+fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
@@ -74,21 +83,21 @@ fi
CLASSPATH+=":$CORE_DIR/src/main/resources"
CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
if [ -e "$FWDIR/lib_managed" ]; then
- for jar in `find "$FWDIR/lib_managed/jars" -name '*jar'`; do
- CLASSPATH+=":$jar"
- done
- for jar in `find "$FWDIR/lib_managed/bundles" -name '*jar'`; do
+ CLASSPATH+=":$FWDIR/lib_managed/jars/*"
+ CLASSPATH+=":$FWDIR/lib_managed/bundles/*"
+fi
+CLASSPATH+=":$REPL_DIR/lib/*"
+if [ -e repl-bin/target ]; then
+ for jar in `find "repl-bin/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH+=":$jar"
done
fi
-for jar in `find "$REPL_DIR/lib" -name '*jar'`; do
- CLASSPATH+=":$jar"
-done
-for jar in `find "$REPL_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
+CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
+# Add the jars shipped with PySpark; the directory is quoted like the other find calls in
+# this script so an install path containing spaces or glob characters is passed intact.
+for jar in `find "$PYSPARK_DIR/lib" -name '*jar'`; do
CLASSPATH+=":$jar"
done
-CLASSPATH+=":$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
export CLASSPATH # Needed for spark-shell
# Figure out whether to run our class with java or with the scala launcher.
diff --git a/run2.cmd b/run2.cmd
index 333d0506b0..67f1e465e4 100644
--- a/run2.cmd
+++ b/run2.cmd
@@ -1,6 +1,6 @@
@echo off
-set SCALA_VERSION=2.9.1
+set SCALA_VERSION=2.9.2
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0
@@ -34,6 +34,7 @@ set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
+set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
@@ -42,6 +43,7 @@ set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMP
for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
+for /R "%PYSPARK_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
rem Figure out whether to run our class with java or with the scala launcher.
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
new file mode 100644
index 0000000000..65f79925a4
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
Binary files differ
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
new file mode 100644
index 0000000000..29f45f4adb
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
@@ -0,0 +1 @@
+18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
new file mode 100644
index 0000000000..e3bd62bac0
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
@@ -0,0 +1 @@
+06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
new file mode 100644
index 0000000000..082d35726a
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
+ <modelVersion>4.0.0</modelVersion>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <version>0.7.2-spark</version>
+ <description>POM was created from install:install-file</description>
+</project>
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
new file mode 100644
index 0000000000..92c4132b5b
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
@@ -0,0 +1 @@
+7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
new file mode 100644
index 0000000000..8a1d8a097a
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
@@ -0,0 +1 @@
+d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
new file mode 100644
index 0000000000..720cd51c2f
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<metadata>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <versioning>
+ <release>0.7.2-spark</release>
+ <versions>
+ <version>0.7.2-spark</version>
+ </versions>
+ <lastUpdated>20130121015225</lastUpdated>
+ </versioning>
+</metadata>
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
new file mode 100644
index 0000000000..a4ce5dc9e8
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
@@ -0,0 +1 @@
+e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
new file mode 100644
index 0000000000..b869eaf2a6
--- /dev/null
+++ b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
@@ -0,0 +1 @@
+2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file
diff --git a/streaming/pom.xml b/streaming/pom.xml
new file mode 100644
index 0000000000..3dae815e1a
--- /dev/null
+++ b/streaming/pom.xml
@@ -0,0 +1,155 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.spark-project</groupId>
+ <artifactId>parent</artifactId>
+ <version>0.7.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-streaming</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project Streaming</name>
+ <url>http://spark-project.org/</url>
+
+ <repositories>
+ <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
+ <repository>
+ <id>lib</id>
+ <url>file://${project.basedir}/lib</url>
+ </repository>
+ </repositories>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.eclipse.jetty</groupId>
+ <artifactId>jetty-server</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-mapper-asl</artifactId>
+ <version>1.9.11</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.kafka</groupId>
+ <artifactId>kafka</artifactId>
+ <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flume</groupId>
+ <artifactId>flume-ng-sdk</artifactId>
+ <version>1.2.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.github.sgroschupf</groupId>
+ <artifactId>zkclient</artifactId>
+ <version>0.1</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.scalacheck</groupId>
+ <artifactId>scalacheck_${scala.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.novocode</groupId>
+ <artifactId>junit-interface</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-log4j12</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+
+ <profiles>
+ <profile>
+ <id>hadoop1</id>
+ <activation>
+ <property>
+ <name>!hadoopVersion</name>
+ </property>
+ </activation>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop1</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>hadoop2</id>
+ <activation>
+ <property>
+ <name>hadoopVersion</name>
+ <value>2</value>
+ </property>
+ </activation>
+ <dependencies>
+ <dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-core</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-core</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <configuration>
+ <classifier>hadoop2</classifier>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+</project>
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
new file mode 100644
index 0000000000..2f3adb39c2
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -0,0 +1,118 @@
+package spark.streaming
+
+import spark.{Logging, Utils}
+
+import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.conf.Configuration
+
+import java.io._
+
+
+/**
+ * Serializable snapshot of a StreamingContext taken for the given checkpoint time.
+ *
+ * All fields are copied out of `ssc` while the object is constructed: the SparkContext
+ * settings (master, job name, Spark home, jars), the DStream graph, the checkpoint directory
+ * and the checkpoint duration. The `ssc` constructor parameter itself is marked @transient
+ * and is only read by these field initializers.
+ *
+ * @param ssc            streaming context whose state is captured
+ * @param checkpointTime time for which this checkpoint is being taken
+ */
+private[streaming]
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
+ extends Logging with Serializable {
+ // Settings copied from the underlying SparkContext
+ val master = ssc.sc.master
+ val framework = ssc.sc.jobName
+ val sparkHome = ssc.sc.sparkHome
+ val jars = ssc.sc.jars
+ // Streaming-specific state copied from the StreamingContext
+ val graph = ssc.graph
+ val checkpointDir = ssc.checkpointDir
+ val checkpointDuration: Duration = ssc.checkpointDuration
+
+ /**
+ * Asserts that master, framework, graph and checkpointTime are all non-null and logs
+ * that the checkpoint was validated. sparkHome, jars, checkpointDir and checkpointDuration
+ * are not checked here.
+ */
+ def validate() {
+ assert(master != null, "Checkpoint.master is null")
+ assert(framework != null, "Checkpoint.framework is null")
+ assert(graph != null, "Checkpoint.graph is null")
+ assert(checkpointTime != null, "Checkpoint.checkpointTime is null")
+ logInfo("Checkpoint for time " + checkpointTime + " validated")
+ }
+}
+
+/**
+ * Convenience class to speed up the writing of graph checkpoint to file.
+ *
+ * The checkpoint is always written to the file "graph" inside `checkpointDir`; the Hadoop
+ * configuration and file system handle are created once and reused for every write.
+ *
+ * @param checkpointDir directory in which the "graph" checkpoint file is kept
+ */
+private[streaming]
+class CheckpointWriter(checkpointDir: String) extends Logging {
+  val file = new Path(checkpointDir, "graph")
+  val conf = new Configuration()
+  var fs = file.getFileSystem(conf)
+  val maxAttempts = 3
+
+  /**
+   * Serialize `checkpoint` into the checkpoint file, retrying up to `maxAttempts` times
+   * when an IOException occurs. An already existing checkpoint file is first moved to
+   * "graph.bk". If every attempt fails, an error is logged and the method returns normally.
+   *
+   * @param checkpoint the checkpoint to write
+   */
+  def write(checkpoint: Checkpoint) {
+    // TODO: maybe do this in a different thread from the main stream execution thread
+    var attempts = 0
+    while (attempts < maxAttempts) {
+      attempts += 1
+      try {
+        logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'")
+        if (fs.exists(file)) {
+          val bkFile = new Path(file.getParent, file.getName + ".bk")
+          // deleteSource = true and overwrite = true, i.e. this is effectively a move
+          FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
+          logDebug("Moved existing checkpoint file to " + bkFile)
+        }
+        val fos = fs.create(file)
+        try {
+          val oos = new ObjectOutputStream(fos)
+          oos.writeObject(checkpoint)
+          oos.close()
+        } finally {
+          // Release the underlying file stream even when serialization fails part-way,
+          // so a failed attempt does not leak an open handle before the retry.
+          fos.close()
+        }
+        logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'")
+        return
+      } catch {
+        case ioe: IOException =>
+          logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
+      }
+    }
+    logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'")
+  }
+}
+
+
+private[streaming]
+object CheckpointReader extends Logging {
+
+  /**
+   * Load a checkpoint from `path`.
+   *
+   * The candidates are tried in order: `path`/graph, `path`/graph.bk, `path` itself and
+   * `path`.bk. The first one that exists and deserializes is validated and returned; a
+   * candidate that fails with an Exception is logged and the next one is tried.
+   *
+   * @param path checkpoint directory (or checkpoint file)
+   * @return the first checkpoint that could be loaded
+   * @throws Exception if none of the candidate files yields a checkpoint
+   */
+  def read(path: String): Checkpoint = {
+    val fs = new Path(path).getFileSystem(new Configuration())
+    val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk"))
+
+    // A plain loop (rather than foreach with a closure) so that `return` below is an ordinary
+    // return and not a non-local return implemented with a control-flow exception.
+    val candidates = attempts.iterator
+    while (candidates.hasNext) {
+      val file = candidates.next()
+      if (fs.exists(file)) {
+        logInfo("Attempting to load checkpoint from file '" + file + "'")
+        try {
+          val fis = fs.open(file)
+          val cp = try {
+            // ObjectInputStream uses the last user-defined class loader on the stack to find
+            // classes, which may be the wrong class loader. Hence, an inherited version of
+            // ObjectInputStream is used to explicitly use the current thread's context class
+            // loader to find and load classes. This is a well-known Java issue and has popped
+            // up in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+            val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+            ois.readObject.asInstanceOf[Checkpoint]
+          } finally {
+            // Always release the file stream, including when deserialization fails.
+            fis.close()
+          }
+          // NOTE: `fs` is deliberately not closed here. Hadoop hands out cached FileSystem
+          // instances, so closing it could break other users of the same file system.
+          cp.validate()
+          logInfo("Checkpoint successfully loaded from file '" + file + "'")
+          logInfo("Checkpoint was generated at time " + cp.checkpointTime)
+          return cp
+        } catch {
+          case e: Exception =>
+            logError("Error loading checkpoint from file '" + file + "'", e)
+        }
+      } else {
+        logWarning("Could not read checkpoint from file '" + file + "' as it does not exist")
+      }
+    }
+    throw new Exception("Could not read checkpoint from path '" + path + "'")
+  }
+}
+
+/**
+ * An ObjectInputStream that resolves classes through an explicitly supplied class loader
+ * first, and only falls back to the default ObjectInputStream resolution when that fails.
+ *
+ * @param inputStream_ stream to deserialize from
+ * @param loader       class loader that is asked first for every class in the stream
+ */
+private[streaming]
+class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) {
+  override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+    try {
+      // Class.forName, unlike ClassLoader.loadClass, also understands array class names
+      // (e.g. "[Ljava.lang.String;"), so arrays of user classes are resolved through
+      // `loader` as well. `false` = do not initialize the class while resolving it.
+      Class.forName(desc.getName, false, loader)
+    } catch {
+      // Deliberate best effort: whatever the supplied loader cannot resolve (for example
+      // primitive type names) is left to the standard resolution logic.
+      case e: Exception => super.resolveClass(desc)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
new file mode 100644
index 0000000000..b11ef443dc
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -0,0 +1,657 @@
+package spark.streaming
+
+import spark.streaming.dstream._
+import StreamingContext._
+//import Time._
+
+import spark.{RDD, Logging}
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+
+/**
+ * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]]
+ * for more details on RDDs). DStreams can either be created from live data (such as data from
+ * HDFS, Kafka or Flume) or be generated by transforming existing DStreams using operations
+ * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
+ * DStream periodically generates a RDD, either from live data or by transforming the RDD generated
+ * by a parent DStream.
+ *
+ * This class contains the basic operations available on all DStreams, such as `map`, `filter` and
+ * `window`. In addition, [[spark.streaming.PairDStreamFunctions]] contains operations available
+ * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations
+ * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)]) through
+ * implicit conversions when `spark.streaming.StreamingContext._` is imported.
+ *
+ * Internally, a DStream is characterized by a few basic properties:
+ * - A list of other DStreams that the DStream depends on
+ * - A time interval at which the DStream generates an RDD
+ * - A function that is used to generate an RDD after each time interval
+ */
+
+abstract class DStream[T: ClassManifest] (
+ @transient protected[streaming] var ssc: StreamingContext
+ ) extends Serializable with Logging {
+
+ initLogging()
+
+ // =======================================================================
+ // Methods that should be implemented by subclasses of DStream
+ // =======================================================================
+
+ /** Time interval after which the DStream generates a RDD */
+ def slideDuration: Duration
+
+ /** List of parent DStreams on which this DStream depends on */
+ def dependencies: List[DStream[_]]
+
+ /** Method that generates a RDD for the given time */
+ def compute (validTime: Time): Option[RDD[T]]
+
+ // =======================================================================
+ // Methods and fields available on all DStreams
+ // =======================================================================
+
+ // RDDs generated, marked as protected[streaming] so that testsuites can access it
+ @transient
+ protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
+
+ // Time zero for the DStream
+ protected[streaming] var zeroTime: Time = null
+
+ // Duration for which the DStream will remember each RDD created
+ protected[streaming] var rememberDuration: Duration = null
+
+ // Storage level of the RDDs in the stream
+ protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
+
+ // Checkpoint details
+ protected[streaming] val mustCheckpoint = false
+ protected[streaming] var checkpointDuration: Duration = null
+ protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]())
+
+ // Reference to whole DStream graph
+ protected[streaming] var graph: DStreamGraph = null
+
+ protected[streaming] def isInitialized = (zeroTime != null)
+
+ // Duration for which the DStream requires its parent DStream to remember each RDD created
+ protected[streaming] def parentRememberDuration = rememberDuration
+
+ /** Returns the StreamingContext associated with this DStream */
+ def context() = ssc
+
+  /**
+   * Set the storage level with which the RDDs of this DStream are persisted.
+   * This is only allowed while the DStream is not yet initialized (zeroTime not set).
+   *
+   * @param level storage level to use for the generated RDDs
+   * @return this DStream, to allow chaining
+   * @throws UnsupportedOperationException if the DStream has already been initialized
+   */
+  def persist(level: StorageLevel): DStream[T] = {
+    if (isInitialized) throw new UnsupportedOperationException(
+      "Cannot change storage level of an DStream after streaming context has started")
+    storageLevel = level
+    this
+  }
+
+ /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+ def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY_SER)
+
+ /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+ def cache(): DStream[T] = persist()
+
+  /**
+   * Enable periodic checkpointing of the RDDs of this DStream. As a side effect the
+   * DStream is also marked for persistence with the default storage level (via persist()).
+   *
+   * @param interval Time interval after which generated RDD will be checkpointed
+   * @return this DStream, to allow chaining
+   * @throws UnsupportedOperationException if the DStream has already been initialized
+   */
+  def checkpoint(interval: Duration): DStream[T] = {
+    if (this.isInitialized) throw new UnsupportedOperationException(
+      "Cannot change checkpoint interval of an DStream after streaming context has started")
+    persist()
+    this.checkpointDuration = interval
+    this
+  }
+
+  /**
+   * Initialize the DStream by setting the "zero" time, based on which the validity of
+   * future times is calculated, and derive default checkpoint and remember durations.
+   * The parent DStreams are initialized recursively with the same zero time.
+   *
+   * @param time the zero time of the stream
+   * @throws Exception if a different zero time has already been set
+   */
+  protected[streaming] def initialize(time: Time) {
+    val conflictingZeroTime = zeroTime != null && zeroTime != time
+    if (conflictingZeroTime) {
+      throw new Exception("ZeroTime is already initialized to " + zeroTime
+        + ", cannot initialize it again to " + time)
+    }
+    zeroTime = time
+
+    // Streams that must checkpoint but have no interval yet get the larger of
+    // slideDuration and 10 seconds
+    if (mustCheckpoint && checkpointDuration == null) {
+      checkpointDuration = slideDuration.max(Seconds(10))
+      logInfo("Checkpoint interval automatically set to " + checkpointDuration)
+    }
+
+    // Lower bound for rememberDuration: the slide duration, or twice the checkpoint
+    // interval (so that the latest checkpoint is not forgotten) when the slide duration
+    // does not exceed the checkpoint interval
+    val slide = slideDuration
+    val minRememberDuration =
+      if (checkpointDuration != null && slide <= checkpointDuration) checkpointDuration * 2
+      else slide
+    if (rememberDuration == null || rememberDuration < minRememberDuration) {
+      rememberDuration = minRememberDuration
+    }
+
+    // Recurse into the parents
+    dependencies.foreach(parent => parent.initialize(zeroTime))
+  }
+
+  /**
+   * Check that the remember duration, checkpoint interval, checkpoint directory, storage
+   * level and metadata cleaner delay of this DStream are consistent with each other, then
+   * validate all parent DStreams recursively and log the effective settings.
+   * Every check is an assertion with a message that explains how to fix the configuration.
+   */
+  protected[streaming] def validate() {
+    assert(rememberDuration != null, "Remember duration is set to null")
+
+    assert(
+      !mustCheckpoint || checkpointDuration != null,
+      "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." +
+        " Please use DStream.checkpoint() to set the interval."
+    )
+
+    assert(
+      checkpointDuration == null || ssc.sc.checkpointDir.isDefined,
+      "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" +
+        " or SparkContext.checkpoint() to set the checkpoint directory."
+    )
+
+    assert(
+      checkpointDuration == null || checkpointDuration >= slideDuration,
+      "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+        checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " +
+        "Please set it to at least " + slideDuration + "."
+    )
+
+    assert(
+      checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration),
+      "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+        checkpointDuration + " which is not a multiple of its slide time (" + slideDuration + "). " +
+        "Please set it to a multiple of " + slideDuration + "."
+    )
+
+    assert(
+      checkpointDuration == null || storageLevel != StorageLevel.NONE,
+      "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " +
+        "level has not been set to enable persisting. Please use DStream.persist() to set the " +
+        "storage level to use memory for better checkpointing performance."
+    )
+
+    assert(
+      checkpointDuration == null || rememberDuration > checkpointDuration,
+      "The remember duration for " + this.getClass.getSimpleName + " has been set to " +
+        rememberDuration + " which is not more than the checkpoint interval (" +
+        checkpointDuration + "). Please set it to higher than " + checkpointDuration + "."
+    )
+
+    // A negative delay is treated as "no cleanup constraint" by the check below
+    val metadataCleanerDelay = spark.util.MetadataCleaner.getDelaySeconds
+    logInfo("metadataCleanupDelay = " + metadataCleanerDelay)
+    assert(
+      // 1000L forces the seconds-to-milliseconds conversion into Long arithmetic
+      // (guards against overflow should the delay be an Int - TODO confirm its type)
+      metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000L,
+      "It seems you are doing some DStream window operation or setting a checkpoint interval " +
+        "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " +
+        "than " + rememberDuration.milliseconds + " milliseconds. But Spark's metadata cleanup " +
+        "delay is set to " + (metadataCleanerDelay / 60.0) + " minutes, which is not sufficient. Please set " +
+        "the Java property 'spark.cleaner.delay' to more than " +
+        math.ceil(rememberDuration.milliseconds.toDouble / 60000.0).toInt + " minutes."
+    )
+
+    dependencies.foreach(_.validate())
+
+    logInfo("Slide time = " + slideDuration)
+    logInfo("Storage level = " + storageLevel)
+    logInfo("Checkpoint interval = " + checkpointDuration)
+    logInfo("Remember duration = " + rememberDuration)
+    logInfo("Initialized and validated " + this)
+  }
+
+  /**
+   * Attach this DStream and, recursively, all of its parents to the given StreamingContext.
+   *
+   * @param s the context to attach to
+   * @throws Exception if a different context has already been set
+   */
+  protected[streaming] def setContext(s: StreamingContext) {
+    val boundToOtherContext = ssc != null && ssc != s
+    if (boundToOtherContext) {
+      throw new Exception("Context is already set in " + this + ", cannot set it again")
+    }
+    ssc = s
+    logInfo("Set context for " + this)
+    for (parent <- dependencies) parent.setContext(ssc)
+  }
+
+  /**
+   * Attach this DStream and, recursively, all of its parents to the given DStreamGraph.
+   *
+   * @param g the graph to attach to
+   * @throws Exception if a different graph has already been set
+   */
+  protected[streaming] def setGraph(g: DStreamGraph) {
+    val boundToOtherGraph = graph != null && graph != g
+    if (boundToOtherGraph) {
+      throw new Exception("Graph is already set in " + this + ", cannot set it again")
+    }
+    graph = g
+    for (parent <- dependencies) parent.setGraph(graph)
+  }
+
+  /**
+   * Raise the remember duration of this DStream to `duration` when that is longer than the
+   * current one (or when none has been set yet), then propagate parentRememberDuration
+   * to all parent DStreams. A null `duration` leaves this DStream unchanged.
+   *
+   * @param duration minimum duration for which generated RDDs should be remembered
+   */
+  protected[streaming] def remember(duration: Duration) {
+    // rememberDuration is null until initialize() has run; comparing against null would
+    // presumably fail inside Duration's `>` (TODO confirm), so an unset value is simply
+    // replaced by the requested duration.
+    if (duration != null && (rememberDuration == null || duration > rememberDuration)) {
+      rememberDuration = duration
+      logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
+    }
+    dependencies.foreach(_.remember(parentRememberDuration))
+  }
+
+  /**
+   * Check whether `time` is valid for generating an RDD: it must lie after zeroTime and
+   * its distance from zeroTime must be a multiple of slideDuration.
+   *
+   * @throws Exception if the DStream has not been initialized yet
+   */
+  protected def isTimeValid(time: Time): Boolean = {
+    if (!isInitialized) {
+      throw new Exception (this + " has not been initialized")
+    }
+    !(time <= zeroTime) && (time - zeroTime).isMultipleOf(slideDuration)
+  }
+
+  /**
+   * Retrieve a precomputed RDD of this DStream, or compute it. This is an internal
+   * method that should not be called directly.
+   *
+   * An RDD already stored in generatedRDDs for `time` is returned as is. Otherwise, if
+   * `time` is valid for this DStream, the RDD is computed, persisted and/or marked for
+   * checkpointing according to the stream's settings, remembered in generatedRDDs and
+   * returned. None is returned when the time is not valid or nothing could be computed.
+   */
+  protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
+    generatedRDDs.get(time).orElse {
+      if (!isTimeValid(time)) {
+        None
+      } else {
+        compute(time).map { newRDD =>
+          if (storageLevel != StorageLevel.NONE) {
+            newRDD.persist(storageLevel)
+            logInfo("Persisting RDD " + newRDD.id + " for time " + time + " to " + storageLevel + " at time " + time)
+          }
+          // Checkpoint only the RDDs that fall on a multiple of the checkpoint interval
+          if (checkpointDuration != null && (time - zeroTime).isMultipleOf(checkpointDuration)) {
+            newRDD.checkpoint()
+            logInfo("Marking RDD " + newRDD.id + " for time " + time + " for checkpointing at time " + time)
+          }
+          generatedRDDs.put(time, newRDD)
+          newRDD
+        }
+      }
+    }
+  }
+
+  /**
+   * Generate a SparkStreaming job for the given time. This is an internal method that
+   * should not be called directly. The default implementation creates a job that
+   * materializes the corresponding RDD by running a no-op function over every partition;
+   * no job is created when there is no RDD for the given time. Subclasses of DStream may
+   * override this (eg. ForEachDStream).
+   */
+  protected[streaming] def generateJob(time: Time): Option[Job] = {
+    getOrCompute(time).map { rdd =>
+      val jobFunc = () => {
+        val emptyFunc = { (iterator: Iterator[T]) => {} }
+        ssc.sc.runJob(rdd, emptyFunc)
+      }
+      new Job(time, jobFunc)
+    }
+  }
+
+  /**
+   * Dereference RDDs that are older than rememberDuration, i.e. whose time is at or before
+   * `time - rememberDuration`, and do the same recursively for all parent DStreams.
+   *
+   * @param time the current time against which the age of the RDDs is measured
+   */
+  protected[streaming] def forgetOldRDDs(time: Time) {
+    // Take a snapshot of the expired keys first: removing entries from the mutable map
+    // while iterating over its live key view is not guaranteed to visit every key.
+    val expiredTimes = generatedRDDs.keys.filter(t => t <= (time - rememberDuration)).toList
+    expiredTimes.foreach { t =>
+      generatedRDDs.remove(t)
+      logInfo("Forgot RDD of time " + t + " from " + this)
+    }
+    logInfo("Forgot " + expiredTimes.size + " RDDs from " + this)
+    dependencies.foreach(_.forgetOldRDDs(time))
+  }
+
+  /**
+   * Adds metadata to the stream while it is running. This default implementation only
+   * logs non-null metadata and then drops it.
+   * This method should be overridden by subclasses of InputDStream.
+   */
+  protected[streaming] def addMetadata(metadata: Any) {
+    Option(metadata).foreach(m => logInfo("Dropping Metadata: " + m.toString))
+  }
+
+ /**
+ * Refresh the list of checkpointed RDDs that will be saved along with checkpoint of
+ * this stream. This is an internal method that should not be called directly. This is
+ * a default implementation that saves only the file names of the checkpointed RDDs to
+ * checkpointData. Subclasses of DStream (especially those of InputDStream) may override
+ * this method to save custom checkpoint data.
+ *
+ * When at least one generated RDD has a checkpoint file, checkpointData.rdds is replaced
+ * by the new (time -> file) entries and the files of the entries that dropped out are
+ * deleted. When there is none, the existing checkpoint data is left untouched. Parent
+ * DStreams are updated recursively in both cases.
+ *
+ * @param currentTime time of the checkpoint being taken (used by this implementation
+ * for logging and passed on to the parents)
+ */
+ protected[streaming] def updateCheckpointData(currentTime: Time) {
+ logInfo("Updating checkpoint data for time " + currentTime)
+
+ // Get the checkpointed RDDs from the generated RDDs, as (time -> checkpoint file)
+ val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+ .map(x => (x._1, x._2.getCheckpointFile.get))
+
+ // Make a copy of the existing checkpoint data (checkpointed RDDs), needed below to
+ // work out which checkpoint files are no longer referenced
+ val oldRdds = checkpointData.rdds.clone()
+
+ // If the new checkpoint data has checkpoints then replace existing with the new one
+ if (newRdds.size > 0) {
+ checkpointData.rdds.clear()
+ checkpointData.rdds ++= newRdds
+ }
+
+ // Make parent DStreams update their checkpoint data
+ dependencies.foreach(_.updateCheckpointData(currentTime))
+
+ // TODO: remove this, this is just for debugging
+ newRdds.foreach {
+ case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+ }
+
+ // Delete the checkpoint files of the entries that were replaced above; nothing is
+ // deleted when no new checkpoint exists, so the previous data stays usable
+ if (newRdds.size > 0) {
+ (oldRdds -- newRdds.keySet).foreach {
+ case (time, data) => {
+ val path = new Path(data.toString)
+ val fs = path.getFileSystem(new Configuration())
+ // recursive delete
+ fs.delete(path, true)
+ logInfo("Deleted checkpoint file '" + path + "' for time " + time)
+ }
+ }
+ }
+ logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, "
+ + "[" + checkpointData.rdds.mkString(",") + "]")
+ }
+
+  /**
+   * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method
+   * that should not be called directly. The default implementation re-creates one RDD per
+   * checkpoint file name stored in checkpointData and then lets the parent DStreams
+   * restore themselves. Subclasses of DStream that override the updateCheckpointData()
+   * method would also need to override this method.
+   */
+  protected[streaming] def restoreCheckpointData() {
+    val checkpointed = checkpointData.rdds
+    logInfo("Restoring checkpoint data from " + checkpointed.size + " checkpointed RDDs")
+    checkpointed.foreach {
+      case (time, data) =>
+        logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'")
+        val restored = ssc.sc.checkpointFile[T](data.toString)
+        generatedRDDs += ((time, restored))
+    }
+    for (parent <- dependencies) parent.restoreCheckpointData()
+    logInfo("Restored checkpoint data")
+  }
+
+  /**
+   * Custom Java serialization hook. A DStream may only be serialized while its graph has
+   * a checkpoint in progress; any other attempt (for instance when the DStream is captured
+   * in the closure of an RDD operation) is rejected with a NotSerializableException.
+   */
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    logDebug(this.getClass().getSimpleName + ".writeObject used")
+    if (graph == null) {
+      throw new java.io.NotSerializableException("Graph is unexpectedly null when DStream is being serialized.")
+    }
+    graph.synchronized {
+      if (!graph.checkpointInProgress) {
+        val msg = "Object of " + this.getClass.getName + " is being serialized " +
+          " possibly as a part of closure of an RDD operation. This is because " +
+          " the DStream object is being referred to from within the closure. " +
+          " Please rewrite the RDD operation inside this DStream to avoid this. " +
+          " This has been enforced to avoid bloating of Spark tasks " +
+          " with unnecessary objects."
+        throw new java.io.NotSerializableException(msg)
+      }
+      oos.defaultWriteObject()
+    }
+  }
+
+ /**
+ * Custom Java deserialization hook: reads the default serialized fields and then
+ * re-creates generatedRDDs as an empty map.
+ */
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ logDebug(this.getClass().getSimpleName + ".readObject used")
+ ois.defaultReadObject()
+ // generatedRDDs is declared @transient, so it is not part of the serialized form and
+ // has to be given a fresh, empty map here
+ generatedRDDs = new HashMap[Time, RDD[T]] ()
+ }
+
+ // =======================================================================
+ // DStream operations
+ // =======================================================================
+
+ /** Return a new DStream by applying a function to all elements of this DStream. */
+ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
+   // The closure is passed through the SparkContext's clean() before being handed on.
+   val cleanedFunc = ssc.sc.clean(mapFunc)
+   new MappedDStream(this, cleanedFunc)
+ }
+
+ /**
+ * Return a new DStream by applying a function to all elements of this DStream,
+ * and then flattening the results
+ */
+ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
+   // The closure is passed through the SparkContext's clean() before being handed on.
+   val cleanedFunc = ssc.sc.clean(flatMapFunc)
+   new FlatMappedDStream(this, cleanedFunc)
+ }
+
+ /** Return a new DStream containing only the elements that satisfy a predicate. */
+ def filter(filterFunc: T => Boolean): DStream[T] = {
+   // Pass the predicate through clean() like map, flatMap and mapPartitions do, so that
+   // all user-supplied closures of this class are treated consistently.
+   new FilteredDStream(this, ssc.sc.clean(filterFunc))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying glom() to each RDD of
+ * this DStream. Applying glom() to an RDD coalesces all elements within each partition into
+ * an array.
+ */
+ def glom(): DStream[Array[T]] = {
+   // Pure wrapper: the new stream takes this stream as its only input.
+   new GlommedDStream(this)
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
+ * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
+ * of the RDD.
+ */
+ def mapPartitions[U: ClassManifest](
+     mapPartFunc: Iterator[T] => Iterator[U],
+     preservePartitioning: Boolean = false
+   ): DStream[U] = {
+   // The closure is cleaned first; preservePartitioning is forwarded unchanged.
+   val cleanedFunc = ssc.sc.clean(mapPartFunc)
+   new MapPartitionedDStream(this, cleanedFunc, preservePartitioning)
+ }
+
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing each RDD
+ * of this DStream.
+ */
+ def reduce(reduceFunc: (T, T) => T): DStream[T] = {
+   // Put every element under one (null) key so reduceByKey, with a single partition,
+   // folds all elements of an RDD together; the key is dropped again afterwards.
+   val keyed = this.map(x => (null, x))
+   val reduced = keyed.reduceByKey(reduceFunc, 1)
+   reduced.map(_._2)
+ }
+
+ /**
+ * Return a new DStream in which each RDD has a single element generated by counting each RDD
+ * of this DStream.
+ */
+ def count(): DStream[Long] = {
+   // Replace each element by 1L and sum the ones with reduce().
+   val ones = this.map(_ => 1L)
+   ones.reduce((a, b) => a + b)
+ }
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * this DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreach(foreachFunc: RDD[T] => Unit) {
+   // Adapt the one-argument function to the (RDD, Time) form and delegate; the time is ignored.
+   val ignoringTime = (rdd: RDD[T], time: Time) => foreachFunc(rdd)
+   foreach(ignoringTime)
+ }
+
+ /**
+ * Apply a function to each RDD in this DStream. This is an output operator, so
+ * this DStream will be registered as an output stream and therefore materialized.
+ */
+ def foreach(foreachFunc: (RDD[T], Time) => Unit) {
+   // Wrap the cleaned function in a ForEachDStream and register it as an output stream.
+   // The method returns Unit, so the created stream is not handed back to the caller.
+   val cleanedFunc = ssc.sc.clean(foreachFunc)
+   val outputStream = new ForEachDStream(this, cleanedFunc)
+   ssc.registerOutputStream(outputStream)
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
+   // Adapt to the (RDD, Time) form and delegate; the time argument is ignored.
+   val ignoringTime = (rdd: RDD[T], time: Time) => transformFunc(rdd)
+   transform(ignoringTime)
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of this DStream.
+ */
+ def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
+   // The closure is passed through the SparkContext's clean() before being handed on.
+   val cleanedFunc = ssc.sc.clean(transformFunc)
+   new TransformedDStream(this, cleanedFunc)
+ }
+
+ /**
+ * Print the first ten elements of each RDD generated in this DStream. This is an output
+ * operator, so this DStream will be registered as an output stream and therefore materialized.
+ */
+ def print() {
+   val printBatch = (rdd: RDD[T], time: Time) => {
+     // Fetch one element more than is displayed so that truncation can be signalled.
+     val sample = rdd.take(11)
+     println ("-------------------------------------------")
+     println ("Time: " + time)
+     println ("-------------------------------------------")
+     sample.take(10).foreach(println)
+     if (sample.size > 10) println("...")
+     println()
+   }
+   val outputStream = new ForEachDStream(this, ssc.sc.clean(printBatch))
+   ssc.registerOutputStream(outputStream)
+ }
+
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * The new DStream generates RDDs with the same interval as this DStream.
+ * @param windowDuration width of the window; must be a multiple of this DStream's interval.
+ */
+ def window(windowDuration: Duration): DStream[T] = {
+   // Reuse this stream's own slide duration as the sliding interval of the window.
+   val slide = this.slideDuration
+   window(windowDuration, slide)
+ }
+
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def window(windowDuration: Duration, slideDuration: Duration): DStream[T] = {
+   // Both durations are forwarded unchanged to the windowed stream.
+   val windowed = new WindowedDStream(this, windowDuration, slideDuration)
+   windowed
+ }
+
+ /**
+ * Return a new DStream which is computed based on a tumbling window on this DStream.
+ * This is equivalent to window(batchDuration, batchDuration).
+ * @param batchDuration tumbling window duration; must be a multiple of this DStream's
+ * batching interval
+ */
+ def tumble(batchDuration: Duration): DStream[T] = {
+   // A tumbling window is a window whose width and slide are the same duration.
+   window(batchDuration, batchDuration)
+ }
+
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a window over this DStream. windowDuration and slideDuration are as defined
+ * in the window() operation. This is equivalent to
+ * window(windowDuration, slideDuration).reduce(reduceFunc)
+ */
+ def reduceByWindow(
+     reduceFunc: (T, T) => T,
+     windowDuration: Duration,
+     slideDuration: Duration
+   ): DStream[T] = {
+   // Window first, then reduce each windowed RDD to a single element.
+   val windowed = this.window(windowDuration, slideDuration)
+   windowed.reduce(reduceFunc)
+ }
+
+ /**
+  * Variant of reduceByWindow that also takes an inverse reduce function. Every element is
+  * placed under the constant key 1, the pairs are reduced with reduceByKeyAndWindow (passing
+  * both functions and a single partition), and the key is dropped from the result.
+  */
+ def reduceByWindow(
+     reduceFunc: (T, T) => T,
+     invReduceFunc: (T, T) => T,
+     windowDuration: Duration,
+     slideDuration: Duration
+   ): DStream[T] = {
+   val keyed = this.map(x => (1, x))
+   val reduced =
+     keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, 1)
+   reduced.map(_._2)
+ }
+
+ /**
+ * Return a new DStream in which each RDD has a single element generated by counting the number
+ * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the
+ * window() operation. This is equivalent to window(windowDuration, slideDuration).count()
+ */
+ def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = {
+   // Count by summing ones; subtraction is supplied as the inverse of the sum.
+   val ones = this.map(_ => 1L)
+   ones.reduceByWindow(
+     (a: Long, b: Long) => a + b,
+     (a: Long, b: Long) => a - b,
+     windowDuration,
+     slideDuration)
+ }
+
+ /**
+ * Return a new DStream by unifying data of another DStream with this DStream.
+ * @param that Another DStream having the same slideDuration as this DStream.
+ */
+ def union(that: DStream[T]): DStream[T] = {
+   // The union stream is built over this stream followed by the other one.
+   val parents = Array(this, that)
+   new UnionDStream[T](parents)
+ }
+
+ /**
+ * Return all the RDDs defined by the Interval object (both end times included)
+ */
+ protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = {
+   // Unpack the interval's two end points and delegate to the time-based overload.
+   val from = interval.beginTime
+   val to = interval.endTime
+   slice(from, to)
+ }
+
+ /**
+ * Return all the RDDs between 'fromTime' and 'toTime' (both included)
+ */
+ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = {
+   val collected = new ArrayBuffer[RDD[T]]()
+   // Walk backwards from the slide-aligned end time, one slide at a time, stopping as soon
+   // as the time falls below either zeroTime or fromTime. RDDs are therefore collected
+   // from the latest time to the earliest.
+   val batchTimes = Iterator.iterate(toTime.floor(slideDuration))(t => t - slideDuration)
+     .takeWhile(t => t >= zeroTime && t >= fromTime)
+   // Times for which getOrCompute yields no RDD are skipped silently.
+   batchTimes.foreach(t => getOrCompute(t).foreach(rdd => collected += rdd))
+   collected.toSeq
+ }
+
+ /**
+ * Save each RDD in this DStream as a Sequence file of serialized objects.
+ * The file name at each batch interval is generated based on `prefix` and
+ * `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsObjectFiles(prefix: String, suffix: String = "") {
+   // For every batch, derive the file name from prefix, suffix and batch time, then save.
+   this.foreach((rdd: RDD[T], time: Time) => {
+     rdd.saveAsObjectFile(rddToFileName(prefix, suffix, time))
+   })
+ }
+
+ /**
+ * Save each RDD in this DStream as a text file, using string representation
+ * of elements. The file name at each batch interval is generated based on
+ * `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsTextFiles(prefix: String, suffix: String = "") {
+   // For every batch, derive the file name from prefix, suffix and batch time, then save.
+   this.foreach((rdd: RDD[T], time: Time) => {
+     rdd.saveAsTextFile(rddToFileName(prefix, suffix, time))
+   })
+ }
+
+ /** Register this DStream itself as an output stream with the streaming context. */
+ def register() { ssc.registerOutputStream(this) }
+}
+
+/**
+ * Checkpoint information of a DStream: a map from batch time to the data recorded for the
+ * RDD of that time. The values are typed `Any`; DStream.restoreCheckpointData() converts
+ * each one with toString and uses the result as the checkpoint file to load.
+ */
+private[streaming]
+case class DStreamCheckpointData(rdds: HashMap[Time, Any])
+
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
new file mode 100644
index 0000000000..bc4a40d7bc
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -0,0 +1,134 @@
+package spark.streaming
+
+import dstream.InputDStream
+import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
+import collection.mutable.ArrayBuffer
+import spark.Logging
+
+/**
+ * Holds the registered input and output streams together with the timing parameters
+ * (zero time, batch duration, remember duration) shared by them. All operations are
+ * performed while holding this object's monitor.
+ */
+final private[streaming] class DStreamGraph extends Serializable with Logging {
+  initLogging()
+
+  private val inputStreams = new ArrayBuffer[InputDStream[_]]()
+  private val outputStreams = new ArrayBuffer[DStream[_]]()
+
+  private[streaming] var zeroTime: Time = null
+  private[streaming] var batchDuration: Duration = null
+  private[streaming] var rememberDuration: Duration = null
+  // True only while writeObject/readObject of this graph is running; DStream.writeObject
+  // consults it to decide whether serializing a DStream is allowed.
+  private[streaming] var checkpointInProgress = false
+
+  /**
+   * Start the graph at the given time: initialize, configure and validate the output
+   * streams, then start all input streams. Throws if the graph was already started.
+   */
+  private[streaming] def start(time: Time) {
+    this.synchronized {
+      if (zeroTime != null) {
+        throw new Exception("DStream graph computation already started")
+      }
+      zeroTime = time
+      outputStreams.foreach(_.initialize(zeroTime))
+      outputStreams.foreach(_.remember(rememberDuration))
+      outputStreams.foreach(_.validate)
+      inputStreams.par.foreach(_.start())
+    }
+  }
+
+  /** Stop all input streams. */
+  private[streaming] def stop() {
+    this.synchronized {
+      inputStreams.par.foreach(_.stop())
+    }
+  }
+
+  /** Pass the streaming context to every output stream. */
+  private[streaming] def setContext(ssc: StreamingContext) {
+    this.synchronized {
+      outputStreams.foreach(_.setContext(ssc))
+    }
+  }
+
+  /** Set the batch duration; it can be set only once, a second call throws. */
+  private[streaming] def setBatchDuration(duration: Duration) {
+    this.synchronized {
+      if (batchDuration != null) {
+        throw new Exception("Batch duration already set as " + batchDuration +
+          ". cannot set it again.")
+      }
+      // The assignment stays inside the lock so that check and set form one atomic step.
+      batchDuration = duration
+    }
+  }
+
+  /** Set the remember duration; it can be set only once, a second call throws. */
+  private[streaming] def remember(duration: Duration) {
+    this.synchronized {
+      if (rememberDuration != null) {
+        throw new Exception("Remember duration already set as " + rememberDuration +
+          ". cannot set it again.")
+      }
+      // The assignment stays inside the lock so that check and set form one atomic step.
+      rememberDuration = duration
+    }
+  }
+
+  /** Attach this graph to the input stream and record the stream. */
+  private[streaming] def addInputStream(inputStream: InputDStream[_]) {
+    this.synchronized {
+      inputStream.setGraph(this)
+      inputStreams += inputStream
+    }
+  }
+
+  /** Attach this graph to the output stream and record the stream. */
+  private[streaming] def addOutputStream(outputStream: DStream[_]) {
+    this.synchronized {
+      outputStream.setGraph(this)
+      outputStreams += outputStream
+    }
+  }
+
+  /** Snapshot of the registered input streams as an array. */
+  private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray }
+
+  /** Snapshot of the registered output streams as an array. */
+  private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray }
+
+  /** Ask every output stream for its job at the given time and collect the results. */
+  private[streaming] def generateRDDs(time: Time): Seq[Job] = {
+    this.synchronized {
+      outputStreams.flatMap(outputStream => outputStream.generateJob(time))
+    }
+  }
+
+  /** Forward forgetOldRDDs(time) to every output stream. */
+  private[streaming] def forgetOldRDDs(time: Time) {
+    this.synchronized {
+      outputStreams.foreach(_.forgetOldRDDs(time))
+    }
+  }
+
+  /** Forward updateCheckpointData(time) to every output stream. */
+  private[streaming] def updateCheckpointData(time: Time) {
+    this.synchronized {
+      outputStreams.foreach(_.updateCheckpointData(time))
+    }
+  }
+
+  /** Forward restoreCheckpointData() to every output stream. */
+  private[streaming] def restoreCheckpointData() {
+    this.synchronized {
+      outputStreams.foreach(_.restoreCheckpointData())
+    }
+  }
+
+  /** Assert that a batch duration is set and at least one output stream is registered. */
+  private[streaming] def validate() {
+    this.synchronized {
+      assert(batchDuration != null, "Batch duration has not been set")
+      //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low")
+      assert(getOutputStreams().size > 0, "No output streams registered, so nothing to execute")
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    this.synchronized {
+      logDebug("DStreamGraph.writeObject used")
+      checkpointInProgress = true
+      try {
+        oos.defaultWriteObject()
+      } finally {
+        // Always clear the flag, even when serialization fails; otherwise every later
+        // DStream serialization would wrongly be treated as part of a checkpoint.
+        checkpointInProgress = false
+      }
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(ois: ObjectInputStream) {
+    this.synchronized {
+      logDebug("DStreamGraph.readObject used")
+      checkpointInProgress = true
+      try {
+        ois.defaultReadObject()
+      } finally {
+        // Always clear the flag, even when deserialization fails.
+        checkpointInProgress = false
+      }
+    }
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala
new file mode 100644
index 0000000000..e4dc579a17
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Duration.scala
@@ -0,0 +1,62 @@
+package spark.streaming
+
+/**
+ * A length of time stored as a whole number of milliseconds. Supports comparison,
+ * addition, subtraction, scaling by an integer and division by another duration.
+ */
+case class Duration (private val millis: Long) {
+
+  def < (that: Duration): Boolean = millis < that.millis
+
+  def <= (that: Duration): Boolean = millis <= that.millis
+
+  def > (that: Duration): Boolean = millis > that.millis
+
+  def >= (that: Duration): Boolean = millis >= that.millis
+
+  def + (that: Duration): Duration = Duration(millis + that.millis)
+
+  def - (that: Duration): Duration = Duration(millis - that.millis)
+
+  def * (times: Int): Duration = Duration(millis * times)
+
+  /** Integer quotient of the two lengths, i.e. how many times `that` fits into this. */
+  def / (that: Duration): Long = millis / that.millis
+
+  /** True when this length is an exact multiple of `that`. */
+  def isMultipleOf(that: Duration): Boolean = {
+    val remainder = millis % that.millis
+    remainder == 0
+  }
+
+  /** The shorter of the two durations; `that` when both are equal. */
+  def min(that: Duration): Duration = if (that <= this) that else this
+
+  /** The longer of the two durations; `that` when both are equal. */
+  def max(that: Duration): Duration = if (that >= this) that else this
+
+  def isZero: Boolean = millis == 0
+
+  override def toString: String = millis + " ms"
+
+  /** The millisecond count as a plain number string, without a unit. */
+  def toFormattedString: String = millis.toString
+
+  def milliseconds: Long = millis
+}
+
+
+/**
+ * Helper object that creates instance of [[spark.streaming.Duration]] representing
+ * a given number of milliseconds.
+ */
+object Milliseconds {
+  // The argument is already in milliseconds and is used as-is.
+  def apply(milliseconds: Long): Duration = Duration(milliseconds)
+}
+
+/**
+ * Helper object that creates instance of [[spark.streaming.Duration]] representing
+ * a given number of seconds.
+ */
+object Seconds {
+  // One second is 1000 milliseconds.
+  def apply(seconds: Long): Duration = Duration(seconds * 1000L)
+}
+
+/**
+ * Helper object that creates instance of [[spark.streaming.Duration]] representing
+ * a given number of minutes.
+ */
+object Minutes {
+  // One minute is 60000 milliseconds.
+  def apply(minutes: Long): Duration = Duration(minutes * 60000L)
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala
new file mode 100644
index 0000000000..dc21dfb722
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Interval.scala
@@ -0,0 +1,41 @@
+package spark.streaming
+
+/**
+ * A span of time delimited by a begin time and an end time. Intervals can be shifted by a
+ * duration and compared with each other; comparison is only defined between intervals of
+ * equal duration and is decided by their end times.
+ */
+private[streaming]
+class Interval(val beginTime: Time, val endTime: Time) {
+  def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs))
+
+  /** Length of the interval: end time minus begin time. */
+  def duration(): Duration = endTime - beginTime
+
+  /** This interval shifted later by `time`. */
+  def + (time: Duration): Interval = {
+    new Interval(beginTime + time, endTime + time)
+  }
+
+  /** This interval shifted earlier by `time`. */
+  def - (time: Duration): Interval = {
+    new Interval(beginTime - time, endTime - time)
+  }
+
+  /**
+   * True when this interval ends before `that` one.
+   * Throws an Exception when the two intervals have different durations.
+   */
+  def < (that: Interval): Boolean = {
+    if (this.duration != that.duration) {
+      throw new Exception("Comparing two intervals with different durations [" + this + ", " + that + "]")
+    }
+    this.endTime < that.endTime
+  }
+
+  // Interval does not define equals, so `this == that` would only be true for the very same
+  // instance. Expressing <= as "that does not end before this" compares by value instead
+  // and still throws for intervals of different durations.
+  def <= (that: Interval): Boolean = !(that < this)
+
+  def > (that: Interval): Boolean = !(this <= that)
+
+  def >= (that: Interval): Boolean = !(this < that)
+
+  override def toString = "[" + beginTime + ", " + endTime + "]"
+}
+
+object Interval {
+  /**
+   * Interval of the given duration that starts at the current wall-clock time floored
+   * to that duration.
+   */
+  def currentInterval(duration: Duration): Interval = {
+    val now = new Time(System.currentTimeMillis)
+    val begin = now.floor(duration)
+    val end = begin + duration
+    new Interval(begin, end)
+  }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/Job.scala b/streaming/src/main/scala/spark/streaming/Job.scala
new file mode 100644
index 0000000000..67bd8388bc
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Job.scala
@@ -0,0 +1,24 @@
+package spark.streaming
+
+import java.util.concurrent.atomic.AtomicLong
+
+/** A unit of work tied to a time; each instance draws a fresh id from the Job object. */
+private[streaming]
+class Job(val time: Time, func: () => _) {
+  val id = Job.getNewId()
+
+  /** Run the wrapped function and return the elapsed wall-clock time in milliseconds. */
+  def run(): Long = {
+    val startedAt = System.currentTimeMillis
+    func()
+    System.currentTimeMillis - startedAt
+  }
+
+  override def toString = "streaming job " + id + " @ " + time
+}
+
+/** Source of job ids: an atomic counter that starts at 0. */
+private[streaming]
+object Job {
+  val id = new AtomicLong(0L)
+
+  /** Return the current counter value and advance the counter by one. */
+  def getNewId(): Long = {
+    id.getAndIncrement()
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala
new file mode 100644
index 0000000000..3b910538e0
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/JobManager.scala
@@ -0,0 +1,33 @@
+package spark.streaming
+
+import spark.Logging
+import spark.SparkEnv
+import java.util.concurrent.Executors
+
+
+/** Runs streaming jobs on a fixed-size thread pool of `numThreads` threads. */
+private[streaming]
+class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
+
+  /** Runnable that executes one job and logs its total delay and execution time. */
+  class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
+    def run() {
+      // Make the context's SparkEnv available on the pool thread before running the job.
+      SparkEnv.set(ssc.env)
+      try {
+        val executionMs = job.run()
+        // Total delay is measured from the job's time to the moment the job has finished.
+        val totalDelaySec = (System.currentTimeMillis() - job.time.milliseconds) / 1000.0
+        val executionSec = executionMs / 1000.0
+        logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format(
+          totalDelaySec, job.id, executionSec))
+      } catch {
+        // A failing job is logged and does not propagate out of the handler.
+        case e: Exception =>
+          logError("Running " + job + " failed", e)
+      }
+    }
+  }
+
+  initLogging()
+
+  val jobExecutor = Executors.newFixedThreadPool(numThreads)
+
+  /** Submit the job to the executor and log that it was queued. */
+  def runJob(job: Job) {
+    val handler = new JobHandler(ssc, job)
+    jobExecutor.execute(handler)
+    logInfo("Added " + job + " to queue")
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
new file mode 100644
index 0000000000..e4152f3a61
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala
@@ -0,0 +1,151 @@
+package spark.streaming
+
+import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
+import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
+import spark.Logging
+import spark.SparkEnv
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Queue
+
+import akka.actor._
+import akka.pattern.ask
+import akka.util.duration._
+import akka.dispatch._
+
+/** Messages handled by the NetworkInputTracker's actor. */
+private[streaming] sealed trait NetworkInputTrackerMessage
+/** Registers the receiver actor of the stream with the given id. */
+private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
+/** Reports newly received block ids, plus metadata, for the stream with the given id. */
+private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
+/** Removes the registration of the stream's receiver; `msg` is included in the log entry. */
+private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
+
+/**
+ * This class manages the execution of the receivers of NetworkInputDStreams.
+ */
+/**
+ * Tracks the receivers of the given network input streams: an actor accepts registration
+ * and block reports from the receivers, and a separate thread launches the receivers as
+ * a Spark job.
+ */
+private[streaming]
+class NetworkInputTracker(
+    @transient ssc: StreamingContext,
+    @transient networkInputStreams: Array[NetworkInputDStream[_]])
+  extends Logging {
+
+  // Input streams indexed by their id, used to validate and route incoming messages.
+  val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
+  val receiverExecutor = new ReceiverExecutor()
+  // Actor of the currently registered receiver, per stream id.
+  val receiverInfo = new HashMap[Int, ActorRef]
+  // Block ids reported but not yet handed out, per stream id.
+  val receivedBlockIds = new HashMap[Int, Queue[String]]
+  val timeout = 5000.milliseconds
+
+  var currentTime: Time = null
+
+  /** Start the actor and receiver execution thread. */
+  def start() {
+    ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker")
+    receiverExecutor.start()
+  }
+
+  /** Stop the receiver execution thread. */
+  def stop() {
+    // TODO: stop the actor as well
+    receiverExecutor.interrupt()
+    receiverExecutor.stopReceivers()
+  }
+
+  /**
+   * Return all the blocks received from a receiver since the previous call. The queue of
+   * that receiver is drained; an unknown receiver id yields an empty array.
+   */
+  def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+    val queue = receivedBlockIds.synchronized {
+      receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+    }
+    val result = queue.synchronized {
+      queue.dequeueAll(x => true)
+    }
+    logInfo("Stream " + receiverId + " received " + result.size + " blocks")
+    result.toArray
+  }
+
+  /** Actor to receive messages from the receivers. */
+  private class NetworkInputTrackerActor extends Actor {
+    def receive = {
+      case RegisterReceiver(streamId, receiverActor) => {
+        if (!networkInputStreamMap.contains(streamId)) {
+          throw new Exception("Register received for unexpected id " + streamId)
+        }
+        receiverInfo += ((streamId, receiverActor))
+        logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
+        sender ! true
+      }
+      case AddBlocks(streamId, blockIds, metadata) => {
+        // Create the stream's queue on first use, then append to it under its own lock.
+        val tmp = receivedBlockIds.synchronized {
+          if (!receivedBlockIds.contains(streamId)) {
+            receivedBlockIds += ((streamId, new Queue[String]))
+          }
+          receivedBlockIds(streamId)
+        }
+        tmp.synchronized {
+          tmp ++= blockIds
+        }
+        networkInputStreamMap(streamId).addMetadata(metadata)
+      }
+      case DeregisterReceiver(streamId, msg) => {
+        receiverInfo -= streamId
+        logInfo("De-registered receiver for network stream " + streamId
+          + " with message " + msg)
+        //TODO: Do something about the corresponding NetworkInputDStream
+      }
+    }
+  }
+
+  /** This thread class runs all the receivers on the cluster. */
+  class ReceiverExecutor extends Thread {
+    val env = ssc.env
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        startReceivers()
+      } catch {
+        case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+      } finally {
+        stopReceivers()
+      }
+    }
+
+    /**
+     * Get the receivers from the NetworkInputDStreams, distributes them to the
+     * worker nodes as a parallel collection, and runs them.
+     */
+    def startReceivers() {
+      val receivers = networkInputStreams.map(nis => {
+        val rcvr = nis.createReceiver()
+        rcvr.setStreamId(nis.id)
+        rcvr
+      })
+
+      if (receivers.isEmpty) {
+        // Nothing to launch; reducing over an empty array would throw.
+        logInfo("No network receivers to start")
+      } else {
+        // Right now, we only honor preferences if all receivers have them
+        val hasLocationPreferences = receivers.forall(_.getLocationPreference().isDefined)
+
+        // Create the parallel collection of receivers to distribute them on the worker nodes
+        val tempRDD =
+          if (hasLocationPreferences) {
+            // Unwrap the Option: calling toString on it would yield "Some(host)", which is
+            // not a usable location. get is safe because every preference is defined here.
+            val receiversWithPreferences =
+              receivers.map(r => (r, Seq(r.getLocationPreference().get)))
+            ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences)
+          }
+          else {
+            ssc.sc.makeRDD(receivers, receivers.size)
+          }
+
+        // Function to start the receiver on the worker node
+        val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => {
+          if (!iterator.hasNext) {
+            throw new Exception("Could not start receiver as details not found.")
+          }
+          iterator.next().start()
+        }
+        // Distribute the receivers and start them
+        ssc.sc.runJob(tempRDD, startReceiver)
+      }
+    }
+
+    /** Stops the receivers. */
+    def stopReceivers() {
+      // Signal the receivers to stop
+      receiverInfo.values.foreach(_ ! StopReceiver)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
new file mode 100644
index 0000000000..fbcf061126
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala
@@ -0,0 +1,562 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import spark.streaming.dstream.{ReducedWindowedDStream, StateDStream}
+import spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream}
+import spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream}
+
+import spark.{Manifests, RDD, Partitioner, HashPartitioner}
+import spark.SparkContext._
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.mapred.{JobConf, OutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.conf.Configuration
+
+class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)])
+extends Serializable {
+
+ /** The streaming context of the wrapped DStream. */
+ def ssc = { self.ssc }
+
+ /** Hash partitioner with `numPartitions` partitions (Spark's default parallelism if omitted). */
+ private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
+   val partitioner = new HashPartitioner(numPartitions)
+   partitioner
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ */
+ def groupByKey(): DStream[(K, Seq[V])] = {
+   // Delegate to the partitioner-based overload with the default hash partitioner.
+   val partitioner = defaultPartitioner()
+   groupByKey(partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * generate the RDDs with `numPartitions` partitions.
+ */
+ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = {
+   // Delegate to the partitioner-based overload with a hash partitioner of the given size.
+   val partitioner = defaultPartitioner(numPartitions)
+   groupByKey(partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]]
+ * is used to control the partitioning of each RDD.
+ */
+ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = {
+   // Combiner functions that gather all values of a key in an ArrayBuffer.
+   val startBuffer = (value: V) => ArrayBuffer[V](value)
+   val appendValue = (buffer: ArrayBuffer[V], value: V) => (buffer += value)
+   val joinBuffers = (left: ArrayBuffer[V], right: ArrayBuffer[V]) => (left ++ right)
+   val grouped = combineByKey(startBuffer, appendValue, joinBuffers, partitioner)
+   // The combined stream holds ArrayBuffers; it is exposed with the Seq element type.
+   grouped.asInstanceOf[DStream[(K, Seq[V])]]
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the associative reduce function. Hash partitioning is used to generate the RDDs
+ * with Spark's default number of partitions.
+ */
+ def reduceByKey(reduceFunc: (V, V) => V): DStream[(K, V)] = {
+   // Delegate to the partitioner-based overload with the default hash partitioner.
+   val partitioner = defaultPartitioner()
+   reduceByKey(reduceFunc, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs
+ * with `numPartitions` partitions.
+ */
+ def reduceByKey(reduceFunc: (V, V) => V, numPartitions: Int): DStream[(K, V)] = {
+   // Delegate to the partitioner-based overload with a hash partitioner of the given size.
+   val partitioner = defaultPartitioner(numPartitions)
+   reduceByKey(reduceFunc, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the
+ * partitioning of each RDD.
+ */
+ def reduceByKey(reduceFunc: (V, V) => V, partitioner: Partitioner): DStream[(K, V)] = {
+   val mergeFunc = ssc.sc.clean(reduceFunc)
+   // A value is its own initial combiner; the cleaned reduce function then serves both
+   // as the value-merging and the combiner-merging step.
+   val identityCombiner = (value: V) => value
+   combineByKey(identityCombiner, mergeFunc, mergeFunc, partitioner)
+ }
+
+ /**
+ * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
+ * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more
+ * information.
+ */
+ def combineByKey[C: ClassManifest](
+     createCombiner: V => C,
+     mergeValue: (C, V) => C,
+     mergeCombiner: (C, C) => C,
+     partitioner: Partitioner) : DStream[(K, C)] = {
+   // All three functions and the partitioner are forwarded unchanged.
+   val shuffled =
+     new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner)
+   shuffled
+ }
+
+ /**
+ * Create a new DStream by counting the number of values of each key in each RDD. Hash
+ * partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = {
+   // Replace each value by 1L and sum the ones per key.
+   val ones = self.map(pair => (pair._1, 1L))
+   ones.reduceByKey((a: Long, b: Long) => a + b, numPartitions)
+ }
+
+ /**
+ * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to
+ * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
+ * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with
+ * Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ */
+ def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = {
+   // Slide with this stream's own interval and use the default hash partitioner.
+   val slide = self.slideDuration
+   val partitioner = defaultPartitioner()
+   groupByKeyAndWindow(windowDuration, slide, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window. Similar to
+ * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = {
+   // Delegate to the partitioner-based overload with the default hash partitioner.
+   val partitioner = defaultPartitioner()
+   groupByKeyAndWindow(windowDuration, slideDuration, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+ def groupByKeyAndWindow(
+     windowDuration: Duration,
+     slideDuration: Duration,
+     numPartitions: Int
+   ): DStream[(K, Seq[V])] = {
+   // Delegate to the partitioner-based overload with a hash partitioner of the given size.
+   val partitioner = defaultPartitioner(numPartitions)
+   groupByKeyAndWindow(windowDuration, slideDuration, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+ def groupByKeyAndWindow(
+     windowDuration: Duration,
+     slideDuration: Duration,
+     partitioner: Partitioner
+   ): DStream[(K, Seq[V])] = {
+   // Window the stream first, then group the windowed data by key.
+   val windowed = self.window(windowDuration, slideDuration)
+   windowed.groupByKey(partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream
+ * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate
+ * the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ */
+ def reduceByKeyAndWindow(
+     reduceFunc: (V, V) => V,
+     windowDuration: Duration
+   ): DStream[(K, V)] = {
+   // Slide with this stream's own interval and use the default hash partitioner.
+   val slide = self.slideDuration
+   val partitioner = defaultPartitioner()
+   reduceByKeyAndWindow(reduceFunc, windowDuration, slide, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def reduceByKeyAndWindow(
+     reduceFunc: (V, V) => V,
+     windowDuration: Duration,
+     slideDuration: Duration
+   ): DStream[(K, V)] = {
+   // Delegate to the partitioner-based overload with the default hash partitioner.
+   val partitioner = defaultPartitioner()
+   reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with `numPartitions` partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+ def reduceByKeyAndWindow(
+     reduceFunc: (V, V) => V,
+     windowDuration: Duration,
+     slideDuration: Duration,
+     numPartitions: Int
+   ): DStream[(K, V)] = {
+   // Delegate to the partitioner-based overload with a hash partitioner of the given size.
+   val partitioner = defaultPartitioner(numPartitions)
+   reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner)
+ }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to
+ * `DStream.reduceByKey()`, but applies it over a sliding window.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ windowDuration: Duration,
+ slideDuration: Duration,
+ partitioner: Partitioner
+ ): DStream[(K, V)] = {
+ val reducer = ssc.sc.clean(reduceFunc)
+ // Step 1: reduce by key within each RDD of this stream.
+ val reducedPerBatch = self.reduceByKey(reducer, partitioner)
+ // Step 2: gather the pre-reduced RDDs that fall into each window.
+ val windowed = reducedPerBatch.window(windowDuration, slideDuration)
+ // Step 3: reduce by key again across the windowed data.
+ windowed.reduceByKey(reducer, partitioner)
+ }
+
+ /**
+ * Create a new DStream by reducing over a window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without an "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowDuration: Duration,
+ slideDuration: Duration
+ ): DStream[(K, V)] = {
+ // Forward to the partitioner-based incremental overload with the default partitioner.
+ val defaultPart = defaultPartitioner()
+ reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPart)
+ }
+
+ /**
+ * Create a new DStream by reducing over a window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without an "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowDuration: Duration,
+ slideDuration: Duration,
+ numPartitions: Int
+ ): DStream[(K, V)] = {
+ // Translate the partition count into a partitioner, then forward to the
+ // partitioner-based incremental overload.
+ val sizedPartitioner = defaultPartitioner(numPartitions)
+ reduceByKeyAndWindow(
+ reduceFunc, invReduceFunc, windowDuration, slideDuration, sizedPartitioner)
+ }
+
+ /**
+ * Create a new DStream by reducing over a window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without an "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+ def reduceByKeyAndWindow(
+ reduceFunc: (V, V) => V,
+ invReduceFunc: (V, V) => V,
+ windowDuration: Duration,
+ slideDuration: Duration,
+ partitioner: Partitioner
+ ): DStream[(K, V)] = {
+ // Clean both closures through the SparkContext before building the stream.
+ val reducer = ssc.sc.clean(reduceFunc)
+ val inverseReducer = ssc.sc.clean(invReduceFunc)
+ new ReducedWindowedDStream[K, V](
+ self,
+ reducer,
+ inverseReducer,
+ windowDuration,
+ slideDuration,
+ partitioner
+ )
+ }
+
+ /**
+ * Create a new DStream by counting the number of values for each key over a window.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+ def countByKeyAndWindow(
+ windowDuration: Duration,
+ slideDuration: Duration,
+ numPartitions: Int = self.ssc.sc.defaultParallelism
+ ): DStream[(K, Long)] = {
+ // Replace every value by 1L so that summing per key yields a count.
+ val ones = self.map(x => (x._1, 1L))
+ // Addition when records enter the window, subtraction when they leave it.
+ val add = (x: Long, y: Long) => x + y
+ val subtract = (x: Long, y: Long) => x - y
+ ones.reduceByKeyAndWindow(add, subtract, windowDuration, slideDuration, numPartitions)
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S]
+ ): DStream[(K, S)] = {
+ // Forward to the partitioner-based overload with the default partitioner.
+ val defaultPart = defaultPartitioner()
+ updateStateByKey(updateFunc, defaultPart)
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S],
+ numPartitions: Int
+ ): DStream[(K, S)] = {
+ // Translate the partition count into a partitioner, then forward.
+ val sizedPartitioner = defaultPartitioner(numPartitions)
+ updateStateByKey(updateFunc, sizedPartitioner)
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of the key.
+ * [[spark.Partitioner]] is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: (Seq[V], Option[S]) => Option[S],
+ partitioner: Partitioner
+ ): DStream[(K, S)] = {
+ // Lift the per-key update function to the iterator-based form: each record is
+ // (key, new values, previous state); a None result drops the key from the output.
+ val perIteratorUpdate = (records: Iterator[(K, Seq[V], Option[S])]) => {
+ records.flatMap { record =>
+ val nextState = updateFunc(record._2, record._3)
+ nextState.map(state => (record._1, state))
+ }
+ }
+ // Keys are never changed by the lifted function, so the partitioner is remembered.
+ updateStateByKey(perIteratorUpdate, partitioner, true)
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * [[spark.Partitioner]] is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated. Note that
+ * this function may generate a tuple with a different key
+ * than the input key. It is up to the developer to decide whether to
+ * remember the partitioner despite the key being changed.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+ partitioner: Partitioner,
+ rememberPartitioner: Boolean
+ ): DStream[(K, S)] = {
+ // Clean the closure through the SparkContext, then wrap this stream in a StateDStream.
+ val cleanedUpdateFunc = ssc.sc.clean(updateFunc)
+ new StateDStream(self, cleanedUpdateFunc, partitioner, rememberPartitioner)
+ }
+
+
+ def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
+ // Wrap this stream in a MapValuedDStream built from the given value function.
+ val mapped = new MapValuedDStream[K, V, U](self, mapValuesFunc)
+ mapped
+ }
+
+ def flatMapValues[U: ClassManifest](
+ flatMapValuesFunc: V => TraversableOnce[U]
+ ): DStream[(K, U)] = {
+ // Wrap this stream in a FlatMapValuedDStream built from the given value function.
+ val flatMapped = new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc)
+ flatMapped
+ }
+
+ /**
+ * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+ * or `other` DStreams, the generated RDD will contain a tuple with the list of values for that
+ * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * of partitions.
+ */
+ def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
+ // Forward to the partitioner-based overload with the default partitioner.
+ val defaultPart = defaultPartitioner()
+ cogroup(other, defaultPart)
+ }
+
+ /**
+ * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+ * or `other` DStreams, the generated RDD will contain a tuple with the list of values for that
+ * key in both RDDs. Partitioner is used to partition each generated RDD.
+ */
+ def cogroup[W: ClassManifest](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (Seq[V], Seq[W]))] = {
+
+ // CoGroupedDStream is given its parents as untyped pair streams, so both sides are
+ // cast to DStream[(_, _)]; the value types are recovered by the casts further down.
+ val cgd = new CoGroupedDStream[K](
+ Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]),
+ partitioner
+ )
+ // Build the pair functions explicitly, passing the key manifest and the
+ // Seq[Seq[_]] manifest by hand instead of relying on implicit resolution.
+ val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
+ classManifest[K],
+ Manifests.seqSeqManifest
+ )
+ // Each value is matched as a two-element Seq: first the values from `this`,
+ // then the values from `other` (same order as the parents above).
+ pdfs.mapValues {
+ case Seq(vs, ws) =>
+ (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
+ }
+ }
+
+ /**
+ * Join `this` DStream with `other` DStream. HashPartitioner is used
+ * to partition each generated RDD into default number of partitions.
+ */
+ def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
+ // Forward to the partitioner-based overload with the default partitioner.
+ val defaultPart = defaultPartitioner()
+ join[W](other, defaultPart)
+ }
+
+ /**
+ * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
+ * be generated by joining RDDs from `this` and other DStream. Uses the given
+ * Partitioner to partition each generated RDD.
+ */
+ def join[W: ClassManifest](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (V, W))] = {
+ // Cogroup first, then emit every (v, w) combination of the two value lists per key.
+ val grouped = this.cogroup(other, partitioner)
+ grouped.flatMapValues { case (vs, ws) =>
+ vs.iterator.flatMap(v => ws.iterator.map(w => (v, w)))
+ }
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
+ * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ */
+ def saveAsHadoopFiles[F <: OutputFormat[K, V]](
+ prefix: String,
+ suffix: String
+ )(implicit fm: ClassManifest[F]) {
+ // Recover the output format class from the implicit manifest of F.
+ val outputFormatClass = fm.erasure.asInstanceOf[Class[F]]
+ saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, outputFormatClass)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
+ * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ */
+ def saveAsHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ conf: JobConf = new JobConf
+ ) {
+ // For every RDD of the stream: derive the file name from the batch time, then save.
+ self.foreach((rdd: RDD[(K, V)], time: Time) => {
+ val target = rddToFileName(prefix, suffix, time)
+ rdd.saveAsHadoopFile(target, keyClass, valueClass, outputFormatClass, conf)
+ })
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](
+ prefix: String,
+ suffix: String
+ )(implicit fm: ClassManifest[F]) {
+ // Recover the output format class from the implicit manifest of F.
+ val outputFormatClass = fm.erasure.asInstanceOf[Class[F]]
+ saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, outputFormatClass)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsNewAPIHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
+ conf: Configuration = new Configuration
+ ) {
+ // For every RDD of the stream: derive the file name from the batch time, then save.
+ self.foreach((rdd: RDD[(K, V)], time: Time) => {
+ val target = rddToFileName(prefix, suffix, time)
+ rdd.saveAsNewAPIHadoopFile(target, keyClass, valueClass, outputFormatClass, conf)
+ })
+ }
+
+ // Erased runtime class of the key type K, taken from the ClassManifest in scope.
+ private def getKeyClass() = classManifest[K].erasure
+
+ // Erased runtime class of the value type V, taken from the ClassManifest in scope.
+ private def getValueClass() = classManifest[V].erasure
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
new file mode 100644
index 0000000000..c04ed37de8
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -0,0 +1,77 @@
+package spark.streaming
+
+import util.{ManualClock, RecurringTimer, Clock}
+import spark.SparkEnv
+import spark.Logging
+
+/**
+ * Drives a StreamingContext: a recurring timer fires once per batch duration and, on each
+ * tick, asks the DStream graph for jobs, hands them to the JobManager, lets the graph
+ * forget old RDDs and, when due, writes a checkpoint.
+ */
+private[streaming]
+class Scheduler(ssc: StreamingContext) extends Logging {
+
+ initLogging()
+
+ val graph = ssc.graph
+
+ // Value passed to the JobManager; read from a system property, default "1".
+ val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
+ val jobManager = new JobManager(ssc, concurrentJobs)
+
+ // Created only when both a checkpoint interval and directory are set at construction
+ // time; null otherwise. doCheckpoint() must therefore tolerate a null writer.
+ val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ new CheckpointWriter(ssc.checkpointDir)
+ } else {
+ null
+ }
+
+ // The clock implementation is pluggable through a system property.
+ val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock")
+ val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
+ val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
+ longTime => generateRDDs(new Time(longTime)))
+
+ /** Starts (or, when recovering from a checkpoint, restarts) the batch timer. */
+ def start() {
+ // If context was started from checkpoint, then restart timer such that
+ // this timer's triggers occur at the same time as the original timer.
+ // Otherwise just start the timer from scratch, and initialize graph based
+ // on this first trigger time of the timer.
+ if (ssc.isCheckpointPresent) {
+ // If manual clock is being used for testing, then
+ // either set the manual clock to the last checkpointed time,
+ // or if the property is defined set it to that time
+ clock match {
+ case manualClock: ManualClock =>
+ val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds
+ val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong
+ manualClock.setTime(lastTime + jumpTime)
+ case _ => // any other clock is left untouched
+ }
+ timer.restart(graph.zeroTime.milliseconds)
+ logInfo("Scheduler's timer restarted")
+ } else {
+ val firstTime = new Time(timer.start())
+ graph.start(firstTime - graph.batchDuration)
+ logInfo("Scheduler's timer started")
+ }
+ logInfo("Scheduler started")
+ }
+
+ /** Stops the timer first, then the graph. */
+ def stop() {
+ timer.stop()
+ graph.stop()
+ logInfo("Scheduler stopped")
+ }
+
+ /** Timer callback: generates and submits the jobs of one batch, then checkpoints if due. */
+ private def generateRDDs(time: Time) {
+ // Set the SparkEnv captured by the context for the calling (timer) thread.
+ SparkEnv.set(ssc.env)
+ logInfo("\n-----------------------------------------------------\n")
+ graph.generateRDDs(time).foreach(jobManager.runJob)
+ graph.forgetOldRDDs(time)
+ doCheckpoint(time)
+ logInfo("Generated RDDs for time " + time)
+ }
+
+ /**
+ * Writes a checkpoint when `time` lies a whole number of checkpoint intervals after the
+ * graph's zero time. If the interval is set but no writer was created (checkpointing was
+ * not fully configured when this scheduler was constructed), the checkpoint is skipped
+ * with a warning instead of failing the timer callback with a NullPointerException.
+ */
+ private def doCheckpoint(time: Time) {
+ val interval = ssc.checkpointDuration
+ if (interval != null && (time - graph.zeroTime).isMultipleOf(interval)) {
+ if (checkpointWriter == null) {
+ logWarning("Checkpoint interval is set but no checkpoint writer is available; " +
+ "skipping checkpoint for time " + time)
+ } else {
+ val startTime = System.currentTimeMillis()
+ graph.updateCheckpointData(time)
+ checkpointWriter.write(new Checkpoint(ssc, time))
+ val stopTime = System.currentTimeMillis()
+ logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms")
+ }
+ }
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
new file mode 100644
index 0000000000..14500bdcb1
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -0,0 +1,411 @@
+package spark.streaming
+
+import spark.streaming.dstream._
+
+import spark.{RDD, Logging, SparkEnv, SparkContext}
+import spark.storage.StorageLevel
+import spark.util.MetadataCleaner
+
+import scala.collection.mutable.Queue
+
+import java.io.InputStream
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
+import java.util.UUID
+
+/**
+ * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
+ * information (such as, cluster URL and job name) to internally create a SparkContext, it provides
+ * methods used to create DStream from various input sources.
+ */
+class StreamingContext private (
+ sc_ : SparkContext,
+ cp_ : Checkpoint,
+ batchDur_ : Duration
+ ) extends Logging {
+
+ /**
+ * Creates a StreamingContext using an existing SparkContext.
+ * @param sparkContext Existing SparkContext
+ * @param batchDuration The time interval at which streaming data will be divided into batches
+ */
+ def this(sparkContext: SparkContext, batchDuration: Duration) =
+ this(sparkContext, null, batchDuration) // null = no checkpoint to restore from
+
+ /**
+ * Creates a StreamingContext by providing the details necessary for creating a new SparkContext.
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param frameworkName A name for your job, to display on the cluster web UI
+ * @param batchDuration The time interval at which streaming data will be divided into batches
+ */
+ def this(master: String, frameworkName: String, batchDuration: Duration) =
+ // A new SparkContext is created by the companion object; null = no checkpoint.
+ this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration)
+
+ /**
+ * Re-creates a StreamingContext from a checkpoint file.
+ * @param path Path either to the directory that was specified as the checkpoint directory, or
+ * to the checkpoint file 'graph' or 'graph.bk'.
+ */
+ def this(path: String) =
+ // No SparkContext and no batch duration are passed; only the checkpoint read from `path`.
+ this(null, CheckpointReader.read(path), null)
+
+ initLogging()
+
+ // At least one of the SparkContext and the checkpoint must be provided.
+ if (sc_ == null && cp_ == null) {
+ throw new Exception("Streaming Context cannot be initialized with " +
+ "both SparkContext and checkpoint as null")
+ }
+
+ // True when this context is being re-created from a checkpoint.
+ protected[streaming] val isCheckpointPresent = (cp_ != null)
+
+ // With a checkpoint, a new SparkContext is built from the settings stored in it;
+ // otherwise the SparkContext handed to the constructor is used.
+ val sc: SparkContext = {
+ if (isCheckpointPresent) {
+ new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars)
+ } else {
+ sc_
+ }
+ }
+
+ // Read after `sc` has been initialised above; Scheduler re-installs it via SparkEnv.set.
+ protected[streaming] val env = SparkEnv.get
+
+ // With a checkpoint, the stored graph is attached to this context and its checkpoint
+ // data restored; otherwise a new graph is created with the given batch duration.
+ protected[streaming] val graph: DStreamGraph = {
+ if (isCheckpointPresent) {
+ cp_.graph.setContext(this)
+ cp_.graph.restoreCheckpointData()
+ cp_.graph
+ } else {
+ assert(batchDur_ != null, "Batch duration for streaming context cannot be null")
+ val newGraph = new DStreamGraph()
+ newGraph.setBatchDuration(batchDur_)
+ newGraph
+ }
+ }
+
+ // Source of the ids handed out by getNewNetworkStreamId().
+ protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0)
+ // Assigned in start() when the graph contains network input streams.
+ protected[streaming] var networkInputTracker: NetworkInputTracker = null
+
+ // With a checkpoint, the SparkContext checkpoint directory is set to a sub-directory of
+ // the stored directory and the stored directory is kept; otherwise null until
+ // checkpoint() is called.
+ protected[streaming] var checkpointDir: String = {
+ if (isCheckpointPresent) {
+ sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(cp_.checkpointDir), true)
+ cp_.checkpointDir
+ } else {
+ null
+ }
+ }
+
+ // Checkpoint interval; null until set by checkpoint() or defaulted in start().
+ protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null
+ // Interrupted in stop() when non-null; not assigned anywhere in this class.
+ protected[streaming] var receiverJobThread: Thread = null
+ // Created in start().
+ protected[streaming] var scheduler: Scheduler = null
+
+ /**
+ * Sets each DStreams in this context to remember RDDs it generated in the last given duration.
+ * DStreams remember RDDs only for a limited duration of time and releases them for garbage
+ * collection. This method allows the developer to specify how long to remember the RDDs (
+ * if the developer wishes to query old data outside the DStream computation).
+ * @param duration Minimum duration that each DStream should remember its RDDs
+ */
+ def remember(duration: Duration) {
+ // The setting is stored on the DStream graph, not on this context.
+ graph.remember(duration)
+ }
+
+ /**
+ * Sets the context to periodically checkpoint the DStream operations for master
+ * fault-tolerance. By default, the graph will be checkpointed every batch interval.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * @param interval checkpoint interval
+ */
+ def checkpoint(directory: String, interval: Duration = null) {
+ if (directory == null) {
+ // A null directory switches checkpointing off again.
+ checkpointDir = null
+ checkpointDuration = null
+ } else {
+ // The SparkContext gets its own sub-directory below the given directory.
+ sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory))
+ checkpointDir = directory
+ checkpointDuration = interval
+ }
+ }
+
+ protected[streaming] def getInitialCheckpoint(): Checkpoint = {
+ // The checkpoint this context was re-created from, or null for a fresh context.
+ if (!isCheckpointPresent) null else cp_
+ }
+
+ // Returns the current value of the id counter and increments it atomically.
+ protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+ * @param hostname ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_ONLY_SER_2.
+ */
+ def kafkaStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
+ ): DStream[T] = {
+ val kafkaInput = new KafkaInputDStream[T](
+ this, hostname, port, groupId, topics, initialOffsets, storageLevel)
+ // Register the stream with the graph before handing it back.
+ registerInputStream(kafkaInput)
+ kafkaInput
+ }
+
+ /**
+ * Create a input stream from network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as UTF8 encoded, \n delimited
+ * lines.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+ */
+ def networkTextStream(
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[String] = {
+ // Same as networkStream, with SocketReceiver.bytesToLines as the byte-stream converter.
+ networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel)
+ }
+
+ /**
+ * Create a input stream from network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as objects using the given
+ * converter.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param converter Function to convert the byte stream to objects
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects received (after converting bytes to objects)
+ */
+ def networkStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ converter: (InputStream) => Iterator[T],
+ storageLevel: StorageLevel
+ ): DStream[T] = {
+ val socketInput = new SocketInputDStream[T](this, hostname, port, converter, storageLevel)
+ // Register the stream with the graph before handing it back.
+ registerInputStream(socketInput)
+ socketInput
+ }
+
+ /**
+ * Creates a input stream from a Flume source.
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def flumeStream (
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[SparkFlumeEvent] = {
+ val flumeInput = new FlumeInputDStream(this, hostname, port, storageLevel)
+ // Register the stream with the graph before handing it back.
+ registerInputStream(flumeInput)
+ flumeInput
+ }
+
+ /**
+ * Create a input stream from network source hostname:port, where data is received
+ * as serialized blocks (serialized using the Spark's serializer) that can be directly
+ * pushed into the block manager without deserializing them. This is the most efficient
+ * way to receive data.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects in the received blocks
+ */
+ def rawNetworkStream[T: ClassManifest](
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
+ ): DStream[T] = {
+ val rawInput = new RawInputDStream[T](this, hostname, port, storageLevel)
+ // Register the stream with the graph before handing it back.
+ registerInputStream(rawInput)
+ rawInput
+ }
+
+ /**
+ * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new file
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
+ */
+ def fileStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ F <: NewInputFormat[K, V]: ClassManifest
+ ] (directory: String): DStream[(K, V)] = {
+ // Uses FileInputDStream's defaults for the path filter and the new-files-only flag.
+ val fileInput = new FileInputDStream[K, V, F](this, directory)
+ registerInputStream(fileInput)
+ fileInput
+ }
+
+ /**
+ * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * @param directory HDFS directory to monitor for new file
+ * @param filter Function to filter paths to process
+ * @param newFilesOnly Should process only new files and ignore existing files in the directory
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
+ */
+ def fileStream[
+ K: ClassManifest,
+ V: ClassManifest,
+ F <: NewInputFormat[K, V]: ClassManifest
+ ] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = {
+ // The caller-supplied path filter and new-files-only flag are passed straight through.
+ val fileInput = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly)
+ registerInputStream(fileInput)
+ fileInput
+ }
+
+
+ /**
+ * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them as text files (using key as LongWritable, value
+ * as Text and input format as TextInputFormat). File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new file
+ */
+ def textFileStream(directory: String): DStream[String] = {
+ val keyValueStream = fileStream[LongWritable, Text, TextInputFormat](directory)
+ // Keep only the Text value of every record, converted to a String.
+ keyValueStream.map(pair => pair._2.toString)
+ }
+
+ /**
+ * Creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ * @param queue Queue of RDDs
+ * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
+ * @param defaultRDD Default RDD is returned by the DStream when the queue is empty
+ * @tparam T Type of objects in the RDD
+ */
+ def queueStream[T: ClassManifest](
+ queue: Queue[RDD[T]],
+ oneAtATime: Boolean = true,
+ defaultRDD: RDD[T] = null
+ ): DStream[T] = {
+ val queueInput = new QueueInputDStream(this, queue, oneAtATime, defaultRDD)
+ // Register the stream with the graph before handing it back.
+ registerInputStream(queueInput)
+ queueInput
+ }
+
+ /**
+ * Create a unified DStream from multiple DStreams of the same type and same interval
+ */
+ def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = {
+ // UnionDStream is constructed from an array of parent streams.
+ val parents = streams.toArray
+ new UnionDStream[T](parents)
+ }
+
+ /**
+ * Registers an input stream that will be started (InputDStream.start() called) to get the
+ * input data.
+ */
+ def registerInputStream(inputStream: InputDStream[_]) {
+ // Input streams are tracked by the DStream graph.
+ graph.addInputStream(inputStream)
+ }
+
+ /**
+ * Registers an output stream that will be computed every interval
+ */
+ def registerOutputStream(outputStream: DStream[_]) {
+ // Output streams are tracked by the DStream graph.
+ graph.addOutputStream(outputStream)
+ }
+
+ protected def validate() {
+ assert(graph != null, "Graph is null")
+ graph.validate()
+
+ // A checkpoint directory without a checkpoint interval is a configuration error.
+ val intervalKnownIfDirSet = checkpointDir == null || checkpointDuration != null
+ assert(
+ intervalKnownIfDirSet,
+ "Checkpoint directory has been set, but the graph checkpointing interval has " +
+ "not been set. Please use StreamingContext.checkpoint() to set the interval."
+ )
+ }
+
+ /**
+ * Starts the execution of the streams.
+ */
+ def start() {
+ // A checkpoint directory without an explicit interval defaults to checkpointing
+ // every batch; this must happen before validate(), which rejects that combination.
+ if (checkpointDir != null && checkpointDuration == null && graph != null) {
+ checkpointDuration = graph.batchDuration
+ }
+
+ validate()
+
+ // Select the input streams that are network input streams.
+ val networkInputStreams = graph.getInputStreams().filter(s => s match {
+ case n: NetworkInputDStream[_] => true
+ case _ => false
+ }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
+
+ if (networkInputStreams.length > 0) {
+ // Start the network input tracker (must start before receivers)
+ networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
+ networkInputTracker.start()
+ }
+
+ // NOTE(review): fixed one-second pause before the scheduler starts; presumably gives
+ // the tracker and receivers time to come up - confirm before changing.
+ Thread.sleep(1000)
+
+ // Start the scheduler
+ scheduler = new Scheduler(this)
+ scheduler.start()
+ }
+
+ /**
+ * Stops the execution of the streams.
+ */
+ def stop() {
+ try {
+ // Each component is only touched when it was actually created.
+ Option(scheduler).foreach(_.stop())
+ Option(networkInputTracker).foreach(_.stop())
+ Option(receiverJobThread).foreach(_.interrupt())
+ sc.stop()
+ logInfo("StreamingContext stopped successfully")
+ } catch {
+ // Stopping is best effort: the first failure is logged and the remaining steps skipped.
+ case e: Exception => logWarning("Error while stopping", e)
+ }
+ }
+}
+
+
+/** Implicit conversions and helpers shared by the streaming classes. */
+object StreamingContext {
+
+ /** Adds the key-value operations of PairDStreamFunctions to any DStream of pairs. */
+ implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](
+ stream: DStream[(K,V)]): PairDStreamFunctions[K, V] = {
+ new PairDStreamFunctions[K, V](stream)
+ }
+
+ /**
+ * Creates the SparkContext used by a new StreamingContext.
+ * @param master Cluster URL passed to the SparkContext
+ * @param frameworkName Job name passed to the SparkContext
+ */
+ protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = {
+
+ // Set the default cleaner delay to an hour (3600 seconds) if not already set.
+ // This should be sufficient for even 1 second interval.
+ if (MetadataCleaner.getDelaySeconds < 0) {
+ MetadataCleaner.setDelaySeconds(3600)
+ }
+ new SparkContext(master, frameworkName)
+ }
+
+ /**
+ * Builds the file name for the RDD of one batch:
+ * the bare time in milliseconds when `prefix` is null, "prefix-TIME_IN_MS" when `suffix`
+ * is null or empty, and "prefix-TIME_IN_MS.suffix" otherwise.
+ */
+ protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = {
+ if (prefix == null) {
+ time.milliseconds.toString
+ } else if (suffix == null || suffix.length == 0) {
+ prefix + "-" + time.milliseconds
+ } else {
+ prefix + "-" + time.milliseconds + "." + suffix
+ }
+ }
+
+ /**
+ * Returns a randomly named (UUID) sub-directory path of the streaming checkpoint
+ * directory; callers pass it to SparkContext.setCheckpointDir.
+ */
+ protected[streaming] def getSparkCheckpointDir(sscCheckpointDir: String): String = {
+ new Path(sscCheckpointDir, UUID.randomUUID.toString).toString
+ }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
new file mode 100644
index 0000000000..5daeb761dd
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -0,0 +1,42 @@
+package spark.streaming
+
+/**
+ * This is a simple class that represents an absolute instant of time.
+ * Internally, it represents time as the difference, measured in milliseconds, between the current
+ * time and midnight, January 1, 1970 UTC. This is the same format as what is returned by
+ * System.currentTimeMillis.
+ */
+case class Time(private val millis: Long) {
+
+  /** Returns the wrapped instant in milliseconds. */
+  def milliseconds: Long = millis
+
+  def < (that: Time): Boolean = (this.millis < that.millis)
+
+  def <= (that: Time): Boolean = (this.millis <= that.millis)
+
+  def > (that: Time): Boolean = (this.millis > that.millis)
+
+  def >= (that: Time): Boolean = (this.millis >= that.millis)
+
+  /** Returns the instant that lies `that` after this one. */
+  def + (that: Duration): Time = new Time(millis + that.milliseconds)
+
+  /** Returns the duration between `that` and this instant (`this - that`). */
+  def - (that: Time): Duration = new Duration(millis - that.millis)
+
+  /** Returns the instant that lies `that` before this one. */
+  def - (that: Duration): Time = new Time(millis - that.milliseconds)
+
+  /**
+   * Rounds this time down to a multiple of `that`, i.e. `floor(millis / t) * t`.
+   *
+   * Computed purely with Long arithmetic. The earlier form passed the result of a Long
+   * division through `math.floor`, which converted it to Double (losing precision for
+   * quotients beyond 2^53) and could not undo the truncation toward zero that Long
+   * division performs on negative values.
+   * @throws ArithmeticException if `that` is zero milliseconds long
+   */
+  def floor(that: Duration): Time = {
+    val t = that.milliseconds
+    val quotient = millis / t
+    // Long division truncates toward zero; step down by one when the division is inexact
+    // and the operands have opposite signs, so the result rounds toward negative infinity.
+    val m = if ((millis % t != 0) && ((millis < 0) != (t < 0))) quotient - 1 else quotient
+    new Time(m * t)
+  }
+
+  /** Returns true when this instant is an exact multiple of `that`. */
+  def isMultipleOf(that: Duration): Boolean =
+    (this.millis % that.milliseconds == 0)
+
+  def min(that: Time): Time = if (this < that) this else that
+
+  def max(that: Time): Time = if (this > that) this else that
+
+  override def toString: String = (millis.toString + " ms")
+
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
new file mode 100644
index 0000000000..2e7466b16c
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala
@@ -0,0 +1,91 @@
+package spark.streaming.api.java
+
+import spark.streaming.{Duration, Time, DStream}
+import spark.api.java.function.{Function => JFunction}
+import spark.api.java.JavaRDD
+import spark.storage.StorageLevel
+
+/**
+ * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[spark.RDD]]
+ * for more details on RDDs). DStreams can either be created from live data (such as, data from
+ * HDFS, Kafka or Flume) or it can be generated by transforming existing DStreams using operations
+ * such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
+ * DStream periodically generates a RDD, either from live data or by transforming the RDD generated
+ * by a parent DStream.
+ *
+ * This class contains the basic operations available on all DStreams, such as `map`, `filter` and
+ * `window`. In addition, [[spark.streaming.api.java.JavaPairDStream]] contains operations available
+ * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations
+ * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)]) through
+ * implicit conversions when `spark.streaming.StreamingContext._` is imported.
+ *
+ * A DStream is internally characterized by a few basic properties:
+ * - A list of other DStreams that the DStream depends on
+ * - A time interval at which the DStream generates an RDD
+ * - A function that is used to generate an RDD after each time interval
+ */
+class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T])
+    extends JavaDStreamLike[T, JavaDStream[T]] {
+
+  /** Return a new DStream holding only the elements for which `f` yields true. */
+  def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = {
+    // Unbox the java.lang.Boolean so the function can act as a Scala predicate.
+    new JavaDStream[T](dstream.filter(elem => f(elem).booleanValue()))
+  }
+
+  /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def cache(): JavaDStream[T] = {
+    new JavaDStream[T](dstream.cache())
+  }
+
+  /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def persist(): JavaDStream[T] = {
+    new JavaDStream[T](dstream.cache())
+  }
+
+  /** Persist the RDDs of this DStream with the supplied storage level */
+  def persist(storageLevel: StorageLevel): JavaDStream[T] = {
+    new JavaDStream[T](dstream.persist(storageLevel))
+  }
+
+  /**
+   * Generate an RDD for the given time.
+   * @return the wrapped RDD, or null when the underlying DStream produces none
+   */
+  def compute(validTime: Time): JavaRDD[T] = {
+    dstream.compute(validTime).map(rdd => new JavaRDD(rdd)).orNull
+  }
+
+  /**
+   * Return a new DStream computed over windowed batches of this DStream.
+   * The new DStream generates RDDs with the same interval as this DStream.
+   * @param windowDuration width of the window; must be a multiple of this DStream's interval.
+   */
+  def window(windowDuration: Duration): JavaDStream[T] = {
+    new JavaDStream[T](dstream.window(windowDuration))
+  }
+
+  /**
+   * Return a new DStream computed over windowed batches of this DStream.
+   * @param windowDuration duration (i.e., width) of the window;
+   *                       must be a multiple of this DStream's interval
+   * @param slideDuration  sliding interval of the window (i.e., the interval after which
+   *                       the new DStream will generate RDDs); must be a multiple of this
+   *                       DStream's interval
+   */
+  def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = {
+    new JavaDStream[T](dstream.window(windowDuration, slideDuration))
+  }
+
+  /**
+   * Return a new DStream computed over a tumbling window on this DStream.
+   * This is equivalent to window(batchDuration, batchDuration).
+   * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval
+   */
+  def tumble(batchDuration: Duration): JavaDStream[T] = {
+    new JavaDStream[T](dstream.tumble(batchDuration))
+  }
+
+  /**
+   * Return a new DStream by unifying data of another DStream with this DStream.
+   * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream.
+   */
+  def union(that: JavaDStream[T]): JavaDStream[T] = {
+    new JavaDStream[T](dstream.union(that.dstream))
+  }
+}
+
+object JavaDStream {
+  /** Implicit view that wraps a Scala DStream in a JavaDStream. */
+  implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] = {
+    new JavaDStream[T](dstream)(implicitly[ClassManifest[T]])
+  }
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
new file mode 100644
index 0000000000..b93cb7865a
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala
@@ -0,0 +1,183 @@
+package spark.streaming.api.java
+
+import java.util.{List => JList}
+import java.lang.{Long => JLong}
+
+import scala.collection.JavaConversions._
+
+import spark.streaming._
+import spark.api.java.JavaRDD
+import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
+import java.util
+import spark.RDD
+import JavaDStream._
+
+/**
+ * Operations shared by the Java-facing DStream wrappers. Each method delegates to the
+ * wrapped Scala `dstream` and adapts arguments and results to Java-friendly types.
+ */
+trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable {
+  implicit val classManifest: ClassManifest[T]
+
+  /** The wrapped Scala DStream. */
+  def dstream: DStream[T]
+
+  /** Implicit view turning a DStream of Scala Longs into a JavaDStream of java.lang.Long. */
+  implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = {
+    in.map(value => new JLong(value))
+  }
+
+  /**
+   * Print the first ten elements of each RDD generated in this DStream. This is an output
+   * operator, so this DStream will be registered as an output stream and therefore materialized.
+   */
+  def print() = {
+    dstream.print()
+  }
+
+  /**
+   * Return a new DStream in which each RDD has a single element generated by counting each RDD
+   * of this DStream.
+   */
+  def count(): JavaDStream[JLong] = {
+    val counts = dstream.count()
+    counts
+  }
+
+  /**
+   * Return a new DStream in which each RDD has a single element generated by counting the number
+   * of elements in a window over this DStream. windowDuration and slideDuration are as defined
+   * in the window() operation. This is equivalent to window(windowDuration, slideDuration).count()
+   */
+  def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = {
+    val windowedCounts = dstream.countByWindow(windowDuration, slideDuration)
+    windowedCounts
+  }
+
+  /**
+   * Return a new DStream in which each RDD is generated by applying glom() to each RDD of
+   * this DStream. Each glommed array is copied into a java.util.ArrayList.
+   */
+  def glom(): JavaDStream[JList[T]] = {
+    new JavaDStream(
+      dstream.glom().map(partition => new java.util.ArrayList[T](partition.toSeq)))
+  }
+
+  /** Return the StreamingContext associated with this DStream */
+  def context(): StreamingContext = {
+    dstream.context()
+  }
+
+  /** Return a new DStream by applying a function to all elements of this DStream. */
+  def map[R](f: JFunction[T, R]): JavaDStream[R] = {
+    // The function object supplies the ClassManifest of its own result type.
+    val mapped = dstream.map(f)(f.returnType())
+    new JavaDStream(mapped)(f.returnType())
+  }
+
+  /** Return a new DStream by applying a function to all elements of this DStream. */
+  def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = {
+    // The AnyRef manifest is cast to stand in for a manifest of the pair type.
+    val pairManifest =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+    val mapped = dstream.map(f)(pairManifest)
+    new JavaPairDStream(mapped)(f.keyType(), f.valueType())
+  }
+
+  /**
+   * Return a new DStream by applying a function to all elements of this DStream,
+   * and then flattening the results
+   */
+  def flatMap[U](f: FlatMapFunction[T, U]): JavaDStream[U] = {
+    import scala.collection.JavaConverters._
+    // Adapt the Java iterable returned by `f` to a Scala collection.
+    val expand = (elem: T) => f.apply(elem).asScala
+    val flattened = dstream.flatMap(expand)(f.elementType())
+    new JavaDStream(flattened)(f.elementType())
+  }
+
+  /**
+   * Return a new DStream by applying a function to all elements of this DStream,
+   * and then flattening the results
+   */
+  def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairDStream[K, V] = {
+    import scala.collection.JavaConverters._
+    val expand = (elem: T) => f.apply(elem).asScala
+    val pairManifest =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+    val flattened = dstream.flatMap(expand)(pairManifest)
+    new JavaPairDStream(flattened)(f.keyType(), f.valueType())
+  }
+
+  /**
+   * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
+   * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
+   * of the RDD.
+   */
+  def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaDStream[U] = {
+    // Hand `f` a Java iterator and turn the iterable it returns back into a Scala iterator.
+    val perPartition =
+      (part: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(part)).iterator())
+    val mapped = dstream.mapPartitions(perPartition)(f.elementType())
+    new JavaDStream(mapped)(f.elementType())
+  }
+
+  /**
+   * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
+   * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
+   * of the RDD.
+   */
+  def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V])
+  : JavaPairDStream[K, V] = {
+    val perPartition =
+      (part: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(part)).iterator())
+    new JavaPairDStream(dstream.mapPartitions(perPartition))(f.keyType(), f.valueType())
+  }
+
+  /**
+   * Return a new DStream in which each RDD has a single element generated by reducing each RDD
+   * of this DStream.
+   */
+  def reduce(f: JFunction2[T, T, T]): JavaDStream[T] = {
+    val reduced = dstream.reduce(f)
+    reduced
+  }
+
+  /**
+   * Return a new DStream in which each RDD has a single element generated by reducing all
+   * elements in a window over this DStream. windowDuration and slideDuration are as defined in
+   * the window() operation. Both the reduce function and its inverse are forwarded to the
+   * underlying DStream.
+   */
+  def reduceByWindow(
+      reduceFunc: JFunction2[T, T, T],
+      invReduceFunc: JFunction2[T, T, T],
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaDStream[T] = {
+    val reduced =
+      dstream.reduceByWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration)
+    reduced
+  }
+
+  /**
+   * Return all the RDDs between 'fromDuration' to 'toDuration' (both included)
+   */
+  def slice(fromDuration: Duration, toDuration: Duration): JList[JavaRDD[T]] = {
+    new util.ArrayList(
+      dstream.slice(fromDuration, toDuration).map(new JavaRDD(_)).toSeq)
+  }
+
+  /**
+   * Apply a function to each RDD in this DStream. This is an output operator, so
+   * this DStream will be registered as an output stream and therefore materialized.
+   */
+  def foreach(foreachFunc: JFunction[JavaRDD[T], Void]): Unit = {
+    dstream.foreach(batch => foreachFunc.call(new JavaRDD(batch)))
+  }
+
+  /**
+   * Apply a function to each RDD in this DStream, also passing the batch time. This is an
+   * output operator, so this DStream will be registered as an output stream and therefore
+   * materialized.
+   */
+  def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]): Unit = {
+    dstream.foreach((batch, batchTime) => foreachFunc.call(new JavaRDD(batch), batchTime))
+  }
+
+  /**
+   * Return a new DStream in which each RDD is generated by applying a function
+   * on each RDD of this DStream.
+   */
+  def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = {
+    // No manifest for U is available from Java callers; the AnyRef manifest is cast instead.
+    implicit val resultManifest: ClassManifest[U] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+    def scalaTransform (in: RDD[T]): RDD[U] = {
+      transformFunc.call(new JavaRDD[T](in)).rdd
+    }
+    dstream.transform(scalaTransform(_))
+  }
+
+  /**
+   * Return a new DStream in which each RDD is generated by applying a function
+   * on each RDD of this DStream; the function also receives the batch time.
+   */
+  def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = {
+    implicit val resultManifest: ClassManifest[U] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+    def scalaTransform (in: RDD[T], time: Time): RDD[U] = {
+      transformFunc.call(new JavaRDD[T](in), time).rdd
+    }
+    dstream.transform(scalaTransform(_, _))
+  }
+
+  /**
+   * Enable periodic checkpointing of RDDs of this DStream
+   * @param interval Time interval after which generated RDD will be checkpointed
+   */
+  def checkpoint(interval: Duration) = dstream.checkpoint(interval)
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
new file mode 100644
index 0000000000..ef10c091ca
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
@@ -0,0 +1,638 @@
+package spark.streaming.api.java
+
+import java.util.{List => JList}
+import java.lang.{Long => JLong}
+
+import scala.collection.JavaConversions._
+
+import spark.streaming._
+import spark.streaming.StreamingContext._
+import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import spark.Partitioner
+import org.apache.hadoop.mapred.{JobConf, OutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.conf.Configuration
+import spark.api.java.JavaPairRDD
+import spark.storage.StorageLevel
+import com.google.common.base.Optional
+
+class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
+ implicit val kManifiest: ClassManifest[K],
+ implicit val vManifest: ClassManifest[V])
+ extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] {
+
+ // =======================================================================
+ // Methods common to all DStream's
+ // =======================================================================
+
+ /** Returns a new DStream containing only the elements that satisfy a predicate. */
+  def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = {
+    // Unbox the java.lang.Boolean so the function can act as a Scala predicate.
+    val retained = dstream.filter(pair => f(pair).booleanValue())
+    retained
+  }
+
+ /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def cache(): JavaPairDStream[K, V] = {
+    val cached = dstream.cache()
+    cached
+  }
+
+ /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */
+  def persist(): JavaPairDStream[K, V] = {
+    // Delegates to cache() on the wrapped DStream.
+    val persisted = dstream.cache()
+    persisted
+  }
+
+ /** Persists the RDDs of this DStream with the given storage level */
+  def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = {
+    val persisted = dstream.persist(storageLevel)
+    persisted
+  }
+
+ /** Method that generates a RDD for the given Duration */
+  def compute(validTime: Time): JavaPairRDD[K, V] = {
+    // Yields null, not an Option, when the wrapped DStream produces no RDD for this time.
+    dstream.compute(validTime).map(rdd => new JavaPairRDD(rdd)).orNull
+  }
+
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * The new DStream generates RDDs with the same interval as this DStream.
+ * @param windowDuration width of the window; must be a multiple of this DStream's interval.
+ * @return
+ */
+  def window(windowDuration: Duration): JavaPairDStream[K, V] = {
+    val windowed = dstream.window(windowDuration)
+    windowed
+  }
+
+ /**
+ * Return a new DStream which is computed based on windowed batches of this DStream.
+ * @param windowDuration duration (i.e., width) of the window;
+ * must be a multiple of this DStream's interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's interval
+ */
+  def window(windowDuration: Duration, slideDuration: Duration): JavaPairDStream[K, V] = {
+    val windowed = dstream.window(windowDuration, slideDuration)
+    windowed
+  }
+
+ /**
+ * Returns a new DStream which computed based on tumbling window on this DStream.
+ * This is equivalent to window(batchDuration, batchDuration).
+ * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval
+ */
+  def tumble(batchDuration: Duration): JavaPairDStream[K, V] = {
+    val tumbled = dstream.tumble(batchDuration)
+    tumbled
+  }
+
+ /**
+ * Returns a new DStream by unifying data of another DStream with this DStream.
+ * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream.
+ */
+  def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = {
+    // Union is performed on the two wrapped Scala DStreams.
+    val merged = dstream.union(that.dstream)
+    merged
+  }
+
+ // =======================================================================
+ // Methods only for PairDStream's
+ // =======================================================================
+
+ /**
+ * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ */
+  def groupByKey(): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKey()
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to
+ * generate the RDDs with `numPartitions` partitions.
+ */
+  def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKey(numPartitions)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream.
+ * Therefore, the values for each key in `this` DStream's RDDs are grouped into a
+ * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]]
+ * is used to control the partitioning of each RDD.
+ */
+  def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKey(partitioner)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the associative reduce function. Hash partitioning is used to generate the RDDs
+ * with Spark's default number of partitions.
+ */
+  def reduceByKey(func: JFunction2[V, V, V]): JavaPairDStream[K, V] = {
+    val reduced = dstream.reduceByKey(func)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs
+ * with `numPartitions` partitions.
+ */
+  def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairDStream[K, V] = {
+    val reduced = dstream.reduceByKey(func, numPartitions)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are
+ * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the
+ * partitioning of each RDD.
+ */
+  def reduceByKey(func: JFunction2[V, V, V], partitioner: Partitioner): JavaPairDStream[K, V] = {
+    val reduced = dstream.reduceByKey(func, partitioner)
+    reduced
+  }
+
+ /**
+ * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
+ * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more
+ * information.
+ */
+  def combineByKey[C](
+      createCombiner: JFunction[V, C],
+      mergeValue: JFunction2[C, V, C],
+      mergeCombiners: JFunction2[C, C, C],
+      partitioner: Partitioner
+    ): JavaPairDStream[K, C] = {
+    // No manifest for C is available from Java callers; the AnyRef manifest is cast instead.
+    implicit val combinerManifest: ClassManifest[C] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+    val combined =
+      dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner)
+    combined
+  }
+
+ /**
+ * Create a new DStream by counting the number of values of each key in each RDD. Hash
+ * partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+  def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = {
+    val counts = dstream.countByKey(numPartitions)
+    // The companion helper converts the Scala counts to java.lang.Long values.
+    JavaPairDStream.scalaToJavaLong(counts)
+  }
+
+
+ /**
+ * Create a new DStream by counting the number of values of each key in each RDD. Hash
+ * partitioning is used to generate the RDDs with the default number of partitions.
+ */
+  def countByKey(): JavaPairDStream[K, JLong] = {
+    val counts = dstream.countByKey()
+    // The companion helper converts the Scala counts to java.lang.Long values.
+    JavaPairDStream.scalaToJavaLong(counts)
+  }
+
+ /**
+ * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to
+ * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
+ * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with
+ * Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ */
+  def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKeyAndWindow(windowDuration)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window. Similar to
+ * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+  def groupByKeyAndWindow(
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKeyAndWindow(windowDuration, slideDuration)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+  def groupByKeyAndWindow(
+      windowDuration: Duration,
+      slideDuration: Duration,
+      numPartitions: Int
+    ): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+  def groupByKeyAndWindow(
+      windowDuration: Duration,
+      slideDuration: Duration,
+      partitioner: Partitioner
+    ): JavaPairDStream[K, JList[V]] = {
+    val grouped = dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner)
+    // Expose each group of values as a java.util.List.
+    grouped.mapValues(seqAsJavaList _)
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream.
+ * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream
+ * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate
+ * the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration
+    ): JavaPairDStream[K, V] = {
+    // NOTE(review): this file imports spark's Function2 under the alias JFunction2, so the
+    // unqualified Function2 here is scala.Function2. Confirm whether JFunction2 was intended.
+    val reduced = dstream.reduceByKeyAndWindow(reduceFunc, windowDuration)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaPairDStream[K, V] = {
+    val reduced = dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to
+ * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to
+ * generate the RDDs with `numPartitions` partitions.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      numPartitions: Int
+    ): JavaPairDStream[K, V] = {
+    val reduced =
+      dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, numPartitions)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to
+ * `DStream.reduceByKey()`, but applies it over a sliding window.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      partitioner: Partitioner
+    ): JavaPairDStream[K, V] = {
+    val reduced =
+      dstream.reduceByKeyAndWindow(reduceFunc, windowDuration, slideDuration, partitioner)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by reducing over a sliding window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaPairDStream[K, V] = {
+    // NOTE(review): this file imports spark's Function2 under the alias JFunction2, so the
+    // unqualified Function2 here is scala.Function2. Confirm whether JFunction2 was intended.
+    val reduced =
+      dstream.reduceByKeyAndWindow(reduceFunc, invReduceFunc, windowDuration, slideDuration)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by reducing over a sliding window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      numPartitions: Int
+    ): JavaPairDStream[K, V] = {
+    // Incremental variant: both the reduce function and its inverse are forwarded.
+    val reduced = dstream.reduceByKeyAndWindow(
+      reduceFunc, invReduceFunc, windowDuration, slideDuration, numPartitions)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by reducing over a sliding window using incremental computation.
+ * The reduced value over a new window is calculated using the old window's reduced value:
+ * 1. reduce the new values that entered the window (e.g., adding new counts)
+ * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ * This is more efficient than reduceByKeyAndWindow without the "inverse reduce" function.
+ * However, it is applicable to only "invertible reduce functions".
+ * @param reduceFunc associative reduce function
+ * @param invReduceFunc inverse function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ */
+  def reduceByKeyAndWindow(
+      reduceFunc: Function2[V, V, V],
+      invReduceFunc: Function2[V, V, V],
+      windowDuration: Duration,
+      slideDuration: Duration,
+      partitioner: Partitioner
+    ): JavaPairDStream[K, V] = {
+    // Incremental variant: both the reduce function and its inverse are forwarded.
+    val reduced = dstream.reduceByKeyAndWindow(
+      reduceFunc, invReduceFunc, windowDuration, slideDuration, partitioner)
+    reduced
+  }
+
+ /**
+ * Create a new DStream by counting the number of values for each key over a window.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+  def countByKeyAndWindow(
+      windowDuration: Duration,
+      slideDuration: Duration
+    ): JavaPairDStream[K, JLong] = {
+    val counts = dstream.countByKeyAndWindow(windowDuration, slideDuration)
+    // The companion helper converts the Scala counts to java.lang.Long values.
+    JavaPairDStream.scalaToJavaLong(counts)
+  }
+
+ /**
+ * Create a new DStream by counting the number of values for each key over a window.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ */
+ def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int)
+ : JavaPairDStream[K, Long] = {
+ // NOTE(review): the two-argument overload of this method returns JavaPairDStream[K, JLong]
+ // by passing its result through JavaPairDStream.scalaToJavaLong, whereas this overload
+ // exposes the Scala Long value type directly. The asymmetry looks unintentional; confirm
+ // before relying on the value type from Java code.
+ dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions)
+ }
+
+  /**
+   * Adapts a Java state-update function, which works with `JList` and `Optional`, into the
+   * Scala shape `(Seq[V], Option[S]) => Option[S]`.
+   * @param in Java function receiving the values for a key and the key's previous state
+   * @return Scala function that wraps its arguments for `in` and unwraps the `Optional` result
+   * @tparam S State type
+   */
+  private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
+  (Seq[V], Option[S]) => Option[S] = {
+    (newValues: Seq[V], previous: Option[S]) => {
+      // Ascribing JList makes the Seq reach the Java function as a Java list.
+      val javaValues: JList[V] = newValues
+      val javaState: Optional[S] = previous match {
+        case Some(s) => Optional.of(s)
+        case _ => Optional.absent()
+      }
+      val updated: Optional[S] = in.apply(javaValues, javaState)
+      // An absent result maps to None, a present one to Some.
+      if (updated.isPresent) Some(updated.get()) else None
+    }
+  }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @tparam S State type
+ */
+ def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]])
+ : JavaPairDStream[K, S] = {
+ implicit val cm: ClassManifest[S] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]]
+ dstream.updateStateByKey(convertUpdateStateFunction(updateFunc))
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @param numPartitions Number of partitions of each RDD in the new DStream.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
+ numPartitions: Int)
+ : JavaPairDStream[K, S] = {
+ dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions)
+ }
+
+ /**
+ * Create a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of the key.
+ * [[spark.Partitioner]] is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassManifest](
+ updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, S] = {
+ dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
+ }
+
+ def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = {
+ implicit val cm: ClassManifest[U] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ dstream.mapValues(f)
+ }
+
+  /**
+   * Applies `f` to the value of every key-value pair and emits one pair, under the same key,
+   * for each element of the returned iterable.
+   * @param f Java function producing the replacement values for a single input value
+   * @tparam U Value type of the resulting DStream
+   */
+  def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = {
+    import scala.collection.JavaConverters._
+    // No manifest for U is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[U] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+    val expand = (value: V) => f.apply(value).asScala
+    dstream.flatMapValues(expand)
+  }
+
+ /**
+ * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+ * or `other` DStreams, the generated RDD will contain a tuple with the list of values for that
+ * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * of partitions.
+ */
+  def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = {
+    // No manifest for W is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[W] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+    val grouped = dstream.cogroup(other.dstream)
+    // Convert both value sequences of each key into Java lists.
+    grouped.mapValues {
+      case (thisValues, otherValues) => (seqAsJavaList(thisValues), seqAsJavaList(otherValues))
+    }
+  }
+
+ /**
+ * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
+ * or `other` DStreams, the generated RDD will contain a tuple with the list of values for that
+ * key in both RDDs. Partitioner is used to partition each generated RDD.
+ */
+  def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
+  : JavaPairDStream[K, (JList[V], JList[W])] = {
+    // No manifest for W is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[W] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+    val grouped = dstream.cogroup(other.dstream, partitioner)
+    // Convert both value sequences of each key into Java lists.
+    grouped.mapValues {
+      case (thisValues, otherValues) => (seqAsJavaList(thisValues), seqAsJavaList(otherValues))
+    }
+  }
+
+ /**
+ * Join `this` DStream with `other` DStream. HashPartitioner is used
+ * to partition each generated RDD into default number of partitions.
+ */
+ def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = {
+ implicit val cm: ClassManifest[W] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ dstream.join(other.dstream)
+ }
+
+ /**
+ * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
+ * be generated by joining RDDs from `this` and other DStream. Uses the given
+ * Partitioner to partition each generated RDD.
+ */
+ def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
+ : JavaPairDStream[K, (V, W)] = {
+ implicit val cm: ClassManifest[W] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ dstream.join(other.dstream, partitioner)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) {
+ dstream.saveAsHadoopFiles(prefix, suffix)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]]) {
+ dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ conf: JobConf) {
+ dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) {
+ dstream.saveAsNewAPIHadoopFiles(prefix, suffix)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsNewAPIHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
+ dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass)
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
+ * generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
+ */
+ def saveAsNewAPIHadoopFiles(
+ prefix: String,
+ suffix: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
+ conf: Configuration = new Configuration) {
+ // NOTE(review): `conf` carries a default value even though a five-parameter overload with
+ // the same leading parameters is declared in this class, and the JobConf-based
+ // saveAsHadoopFiles variant takes its conf without a default. Confirm the default is
+ // intentional.
+ dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ override val classManifest: ClassManifest[(K, V)] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+}
+
+object JavaPairDStream {
+  /** Wraps a Scala pair DStream in a [[JavaPairDStream]]; usable as an implicit conversion. */
+  implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)])
+  :JavaPairDStream[K, V] =
+    new JavaPairDStream[K, V](dstream)
+
+  /** Re-wraps the DStream underlying a JavaDStream of tuples as a [[JavaPairDStream]]. */
+  def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = {
+    // No manifests for K and V are available here, so the AnyRef manifest is cast to stand in.
+    implicit val keyManifest: ClassManifest[K] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+    implicit val valueManifest: ClassManifest[V] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+    new JavaPairDStream[K, V](dstream.dstream)
+  }
+
+  /** Converts the Scala `Long` values of a pair stream into boxed `JLong` values. */
+  def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long])
+  : JavaPairDStream[K, JLong] = {
+    val pairFunctions = StreamingContext.toPairDStreamFunctions(dstream.dstream)
+    pairFunctions.mapValues((count: Long) => new JLong(count))
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
new file mode 100644
index 0000000000..f82e6a37cc
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala
@@ -0,0 +1,346 @@
+package spark.streaming.api.java
+
+import scala.collection.JavaConversions._
+import java.lang.{Long => JLong, Integer => JInt}
+
+import spark.streaming._
+import dstream._
+import spark.storage.StorageLevel
+import spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import java.io.InputStream
+import java.util.{Map => JMap}
+import spark.api.java.{JavaSparkContext, JavaRDD}
+
+/**
+ * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
+ * information (such as, cluster URL and job name) to internally create a SparkContext, it provides
+ * methods used to create DStream from various input sources.
+ */
+class JavaStreamingContext(val ssc: StreamingContext) {
+
+ // TODOs:
+ // - Test to/from Hadoop functions
+ // - Support creating and registering InputStreams
+
+
+ /**
+ * Creates a StreamingContext.
+ * @param master Name of the Spark Master
+ * @param frameworkName Name to be used when registering with the scheduler
+ * @param batchDuration The time interval at which streaming data will be divided into batches
+ */
+ def this(master: String, frameworkName: String, batchDuration: Duration) =
+ this(new StreamingContext(master, frameworkName, batchDuration))
+
+ /**
+ * Re-creates a StreamingContext from a checkpoint file.
+ * @param path Path either to the directory that was specified as the checkpoint directory, or
+ * to the checkpoint file 'graph' or 'graph.bk'.
+ */
+ def this(path: String) = this (new StreamingContext(path))
+
+ /** The underlying SparkContext */
+ val sc: JavaSparkContext = new JavaSparkContext(ssc.sc)
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+ * @param hostname ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ */
+ def kafkaStream[T](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: JMap[String, JInt])
+ : JavaDStream[T] = {
+ implicit val cmt: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.kafkaStream[T](hostname, port, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+ * @param hostname ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ */
+ def kafkaStream[T](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: JMap[String, JInt],
+ initialOffsets: JMap[KafkaPartitionKey, JLong])
+ : JavaDStream[T] = {
+ implicit val cmt: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.kafkaStream[T](
+ hostname,
+ port,
+ groupId,
+ Map(topics.mapValues(_.intValue()).toSeq: _*),
+ Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kafka Broker.
+ * @param hostname ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ * @param storageLevel RDD storage level. Defaults to memory-only
+ */
+ def kafkaStream[T](
+ hostname: String,
+ port: Int,
+ groupId: String,
+ topics: JMap[String, JInt],
+ initialOffsets: JMap[KafkaPartitionKey, JLong],
+ storageLevel: StorageLevel)
+ : JavaDStream[T] = {
+ implicit val cmt: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ ssc.kafkaStream[T](
+ hostname,
+ port,
+ groupId,
+ Map(topics.mapValues(_.intValue()).toSeq: _*),
+ Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
+ storageLevel)
+ }
+
+ /**
+ * Create an input stream from network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as UTF8 encoded \n delimited
+ * lines.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * (default: StorageLevel.MEMORY_AND_DISK_SER_2)
+ */
+ def networkTextStream(hostname: String, port: Int, storageLevel: StorageLevel)
+ : JavaDStream[String] = {
+ ssc.networkTextStream(hostname, port, storageLevel)
+ }
+
+ /**
+ * Create an input stream from network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as UTF8 encoded \n delimited
+ * lines.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ */
+ def networkTextStream(hostname: String, port: Int): JavaDStream[String] = {
+ ssc.networkTextStream(hostname, port)
+ }
+
+ /**
+ * Create an input stream from network source hostname:port. Data is received using
+ * a TCP socket and the received bytes are interpreted as objects using the given
+ * converter.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param converter Function to convert the byte stream to objects
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects received (after converting bytes to objects)
+ */
+  def networkStream[T](
+      hostname: String,
+      port: Int,
+      converter: JFunction[InputStream, java.lang.Iterable[T]],
+      storageLevel: StorageLevel)
+  : JavaDStream[T] = {
+    // No manifest for T is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cmt: ClassManifest[T] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+    // Adapt the Java converter: its java.lang.Iterable result is handed on as an iterator.
+    val toObjects = (stream: InputStream) => converter.apply(stream).toIterator
+    ssc.networkStream(hostname, port, toObjects, storageLevel)
+  }
+
+ /**
+ * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them as text files (using key as LongWritable, value
+ * as Text and input format as TextInputFormat). File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new file
+ */
+ def textFileStream(directory: String): JavaDStream[String] = {
+ ssc.textFileStream(directory)
+ }
+
+ /**
+ * Create a input stream from network source hostname:port, where data is received
+ * as serialized blocks (serialized using the Spark's serializer) that can be directly
+ * pushed into the block manager without deserializing them. This is the most efficient
+ * way to receive data.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @param storageLevel Storage level to use for storing the received objects
+ * @tparam T Type of the objects in the received blocks
+ */
+ def rawNetworkStream[T](
+ hostname: String,
+ port: Int,
+ storageLevel: StorageLevel): JavaDStream[T] = {
+ implicit val cmt: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port, storageLevel))
+ }
+
+ /**
+ * Create a input stream from network source hostname:port, where data is received
+ * as serialized blocks (serialized using the Spark's serializer) that can be directly
+ * pushed into the block manager without deserializing them. This is the most efficient
+ * way to receive data.
+ * @param hostname Hostname to connect to for receiving data
+ * @param port Port to connect to for receiving data
+ * @tparam T Type of the objects in the received blocks
+ */
+ def rawNetworkStream[T](hostname: String, port: Int): JavaDStream[T] = {
+ implicit val cmt: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port))
+ }
+
+ /**
+ * Creates a input stream that monitors a Hadoop-compatible filesystem
+ * for new files and reads them using the given key-value types and input format.
+ * File names starting with . are ignored.
+ * @param directory HDFS directory to monitor for new file
+ * @tparam K Key type for reading HDFS file
+ * @tparam V Value type for reading HDFS file
+ * @tparam F Input format for reading HDFS file
+ */
+ def fileStream[K, V, F <: NewInputFormat[K, V]](directory: String): JavaPairDStream[K, V] = {
+ implicit val cmk: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val cmv: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val cmf: ClassManifest[F] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]]
+ ssc.fileStream[K, V, F](directory);
+ }
+
+ /**
+ * Creates a input stream from a Flume source.
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ * @param storageLevel Storage level to use for storing the received objects
+ */
+ def flumeStream(hostname: String, port: Int, storageLevel: StorageLevel):
+ JavaDStream[SparkFlumeEvent] = {
+ ssc.flumeStream(hostname, port, storageLevel)
+ }
+
+
+ /**
+ * Creates a input stream from a Flume source.
+ * @param hostname Hostname of the slave machine to which the flume data will be sent
+ * @param port Port of the slave machine to which the flume data will be sent
+ */
+ def flumeStream(hostname: String, port: Int):
+ JavaDStream[SparkFlumeEvent] = {
+ ssc.flumeStream(hostname, port)
+ }
+
+ /**
+ * Registers an output stream that will be computed every interval
+ */
+ def registerOutputStream(outputStream: JavaDStreamLike[_, _]) {
+ ssc.registerOutputStream(outputStream.dstream)
+ }
+
+ /**
+ * Creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ *
+ * NOTE: changes to the queue after the stream is created will not be recognized.
+ * @param queue Queue of RDDs
+ * @tparam T Type of objects in the RDD
+ */
+  def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = {
+    // No manifest for T is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[T] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+    // Copy the underlying RDDs of the Java queue's current contents into a new Scala queue.
+    val scalaQueue = new scala.collection.mutable.Queue[spark.RDD[T]]
+    scalaQueue ++= queue.map(_.rdd)
+    ssc.queueStream(scalaQueue)
+  }
+
+ /**
+ * Creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ *
+ * NOTE: changes to the queue after the stream is created will not be recognized.
+ * @param queue Queue of RDDs
+ * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
+ * @tparam T Type of objects in the RDD
+ */
+  def queueStream[T](queue: java.util.Queue[JavaRDD[T]], oneAtATime: Boolean): JavaDStream[T] = {
+    // No manifest for T is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[T] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+    // Copy the underlying RDDs of the Java queue's current contents into a new Scala queue.
+    val scalaQueue = new scala.collection.mutable.Queue[spark.RDD[T]]
+    scalaQueue ++= queue.map(_.rdd)
+    ssc.queueStream(scalaQueue, oneAtATime)
+  }
+
+ /**
+ * Creates an input stream from a queue of RDDs. In each batch,
+ * it will process either one or all of the RDDs returned by the queue.
+ *
+ * NOTE: changes to the queue after the stream is created will not be recognized.
+ * @param queue Queue of RDDs
+ * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
+ * @param defaultRDD Default RDD is returned by the DStream when the queue is empty
+ * @tparam T Type of objects in the RDD
+ */
+  def queueStream[T](
+      queue: java.util.Queue[JavaRDD[T]],
+      oneAtATime: Boolean,
+      defaultRDD: JavaRDD[T]): JavaDStream[T] = {
+    // No manifest for T is in scope here, so the AnyRef manifest is cast to stand in for it.
+    implicit val cm: ClassManifest[T] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+    // Copy the underlying RDDs of the Java queue's current contents into a new Scala queue.
+    val scalaQueue = new scala.collection.mutable.Queue[spark.RDD[T]]
+    scalaQueue ++= queue.map(_.rdd)
+    ssc.queueStream(scalaQueue, oneAtATime, defaultRDD.rdd)
+  }
+
+ /**
+ * Sets the context to periodically checkpoint the DStream operations for master
+ * fault-tolerance. By default, the graph will be checkpointed every batch interval.
+ * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
+ * @param interval checkpoint interval
+ */
+ def checkpoint(directory: String, interval: Duration = null) {
+ ssc.checkpoint(directory, interval)
+ }
+
+ /**
+ * Sets each DStream in this context to remember the RDDs it generated in the last given duration.
+ * DStreams remember RDDs only for a limited duration of time and release them for garbage
+ * collection. This method allows the developer to specify how long to remember the RDDs
+ * (if the developer wishes to query old data outside the DStream computation).
+ * @param duration Minimum duration that each DStream should remember its RDDs
+ */
+ def remember(duration: Duration) {
+ ssc.remember(duration)
+ }
+
+ /**
+ * Starts the execution of the streams.
+ */
+ def start() = ssc.start()
+
+ /**
+ * Stops the execution of the streams.
+ */
+ def stop() = ssc.stop()
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
new file mode 100644
index 0000000000..ddb1bf6b28
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala
@@ -0,0 +1,40 @@
+package spark.streaming.dstream
+
+import spark.{RDD, Partitioner}
+import spark.rdd.CoGroupedRDD
+import spark.streaming.{Time, DStream, Duration}
+
+/**
+ * DStream whose RDD for a batch time is a CoGroupedRDD built from the RDDs that its parents
+ * produce for that time.
+ *
+ * All parents must belong to the same StreamingContext and share one slide duration.
+ *
+ * @param parents     DStreams to cogroup; must not be empty
+ * @param partitioner Partitioner handed to every generated CoGroupedRDD
+ * @tparam K Key type
+ * @throws IllegalArgumentException if `parents` is empty, or if the parents differ in their
+ *                                  StreamingContext or slide duration
+ */
+private[streaming]
+class CoGroupedDStream[K : ClassManifest](
+    parents: Seq[DStream[(_, _)]],
+    partitioner: Partitioner
+  ) extends DStream[(K, Seq[Seq[_]])](
+    // The superclass argument is evaluated before the class body runs, so an empty `parents`
+    // has to be rejected here; a check in the body would never be reached because
+    // `parents.head` would already have failed with NoSuchElementException.
+    if (parents.isEmpty) throw new IllegalArgumentException("Empty array of parents")
+    else parents.head.ssc
+  ) {
+
+  if (parents.map(_.ssc).distinct.size > 1) {
+    throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+  }
+
+  if (parents.map(_.slideDuration).distinct.size > 1) {
+    throw new IllegalArgumentException("Array of parents have different slide times")
+  }
+
+  override def dependencies = parents.toList
+
+  override def slideDuration: Duration = parents.head.slideDuration
+
+  /**
+   * Cogroups whatever RDDs the parents yield for `validTime`; returns None when no parent
+   * yields an RDD for that time.
+   */
+  override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = {
+    val part = partitioner
+    val rdds = parents.flatMap(_.getOrCompute(validTime))
+    if (rdds.size > 0) {
+      val q = new CoGroupedRDD[K](rdds, part)
+      Some(q)
+    } else {
+      None
+    }
+  }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala
new file mode 100644
index 0000000000..41c3af4694
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ConstantInputDStream.scala
@@ -0,0 +1,19 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.streaming.{Time, StreamingContext}
+
+/**
+ * An input stream that always returns the same RDD on each timestep. Useful for testing.
+ */
+class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
+  extends InputDStream[T](ssc_) {
+
+  /** No-op: this stream has nothing to start. */
+  override def start() {}
+
+  /** No-op: this stream has nothing to stop. */
+  override def stop() {}
+
+  /** Yields the RDD given at construction, whatever the batch time. */
+  override def compute(validTime: Time): Option[RDD[T]] = Some(rdd)
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
new file mode 100644
index 0000000000..1e6ad84b44
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala
@@ -0,0 +1,102 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.rdd.UnionRDD
+import spark.streaming.{StreamingContext, Time}
+
+import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+
+import scala.collection.mutable.HashSet
+
+/**
+ * Input stream that lists `directory` on every batch and reads the files it has not yet
+ * processed through the new Hadoop API with key type K, value type V and input format F.
+ *
+ * @param ssc_         StreamingContext this stream belongs to
+ * @param directory    Directory whose files are listed on every batch
+ * @param filter       Predicate a path must satisfy to be considered at all
+ * @param newFilesOnly When true, `start()` sets the modification-time threshold to the
+ *                     current time; when false, to 0
+ */
+private[streaming]
+class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
+    @transient ssc_ : StreamingContext,
+    directory: String,
+    filter: Path => Boolean = FileInputDStream.defaultFilter,
+    newFilesOnly: Boolean = true)
+  extends InputDStream[(K, V)](ssc_) {
+
+  // Both handles are transient and re-created on demand by path() and fs().
+  @transient private var path_ : Path = null
+  @transient private var fs_ : FileSystem = null
+
+  // Latest modification time already processed, and the files processed with exactly that time.
+  var lastModTime = 0L
+  val lastModTimeFiles = new HashSet[String]()
+
+  /** Path of the monitored directory, created on first use. */
+  def path(): Path = {
+    if (path_ == null) path_ = new Path(directory)
+    path_
+  }
+
+  /** FileSystem of the monitored directory, obtained on first use. */
+  def fs(): FileSystem = {
+    if (fs_ == null) fs_ = path.getFileSystem(new Configuration())
+    fs_
+  }
+
+  /** Initialises the modification-time threshold according to `newFilesOnly`. */
+  override def start() {
+    lastModTime = if (newFilesOnly) System.currentTimeMillis() else 0L
+  }
+
+  override def stop() { }
+
+  /**
+   * Collects the files modified since the previous call and returns the union of their RDDs.
+   * The files seen at the newest modification time are recorded as well, because the
+   * modification time reported by the FileStatus API appears to have a granularity of
+   * seconds: a later file may carry the same timestamp as the newest one already processed,
+   * and the recorded names are what keeps processed files from being picked up again.
+   */
+  override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+    // Filter that selects unprocessed files and tracks the newest modification time it accepts.
+    val newFilter = new PathFilter() {
+      var latestModTime = 0L
+      val latestModTimeFiles = new HashSet[String]()
+
+      def accept(candidate: Path): Boolean = {
+        if (!filter(candidate)) {
+          false
+        } else {
+          val modTime = fs.getFileStatus(candidate).getModificationTime()
+          // Older than the threshold, or at the threshold and already recorded as processed.
+          val alreadyProcessed =
+            modTime < lastModTime ||
+              (modTime == lastModTime && lastModTimeFiles.contains(candidate.toString))
+          if (alreadyProcessed) {
+            false
+          } else {
+            if (modTime > latestModTime) {
+              latestModTime = modTime
+              latestModTimeFiles.clear()
+            }
+            latestModTimeFiles += candidate.toString
+            true
+          }
+        }
+      }
+    }
+
+    val newFiles = fs.listStatus(path, newFilter)
+    logInfo("New files: " + newFiles.map(_.getPath).mkString(", "))
+    if (newFiles.length > 0) {
+      // Advance the threshold and the set of files processed at that threshold.
+      if (lastModTime != newFilter.latestModTime) {
+        lastModTime = newFilter.latestModTime
+        lastModTimeFiles.clear()
+      }
+      lastModTimeFiles ++= newFilter.latestModTimeFiles
+    }
+    val newRDD = new UnionRDD(ssc.sc, newFiles.map(
+      file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString)))
+    Some(newRDD)
+  }
+}
+
+private[streaming]
+object FileInputDStream {
+  /** Default filter: rejects any path whose name, as given by `Path.getName`, starts with ".". */
+  def defaultFilter(path: Path): Boolean = {
+    val startsWithDot = path.getName().startsWith(".")
+    !startsWithDot
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala
new file mode 100644
index 0000000000..e993164f99
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FilteredDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+
+/**
+ * DStream that applies `filterFunc` as a filter to every RDD produced by `parent`.
+ * It depends only on `parent` and reuses the parent's slide duration.
+ */
+private[streaming]
+class FilteredDStream[T: ClassManifest](
+    parent: DStream[T],
+    filterFunc: T => Boolean
+  ) extends DStream[T](parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  /** Filters the parent's RDD for `validTime`, if the parent has one. */
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    for (parentRDD <- parent.getOrCompute(validTime)) yield parentRDD.filter(filterFunc)
+  }
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala
new file mode 100644
index 0000000000..cabd34f5f2
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+import spark.SparkContext._
+
+/**
+ * DStream that applies `flatMapValueFunc` through `flatMapValues` to every RDD produced by
+ * `parent`. It depends only on `parent` and reuses the parent's slide duration.
+ */
+private[streaming]
+class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+    parent: DStream[(K, V)],
+    flatMapValueFunc: V => TraversableOnce[U]
+  ) extends DStream[(K, U)](parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  /** Flat-maps the values of the parent's RDD for `validTime`, if the parent has one. */
+  override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+    for (parentRDD <- parent.getOrCompute(validTime))
+      yield parentRDD.flatMapValues[U](flatMapValueFunc)
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala
new file mode 100644
index 0000000000..a69af60589
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlatMappedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+
+/**
+ * DStream that applies `flatMapFunc` through `flatMap` to every RDD produced by `parent`.
+ * It depends only on `parent` and reuses the parent's slide duration.
+ */
+private[streaming]
+class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
+    parent: DStream[T],
+    flatMapFunc: T => Traversable[U]
+  ) extends DStream[U](parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  /** Flat-maps the parent's RDD for `validTime`, if the parent has one. */
+  override def compute(validTime: Time): Option[RDD[U]] = {
+    for (parentRDD <- parent.getOrCompute(validTime)) yield parentRDD.flatMap(flatMapFunc)
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
new file mode 100644
index 0000000000..efc7058480
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala
@@ -0,0 +1,137 @@
+package spark.streaming.dstream
+
+import spark.streaming.StreamingContext
+
+import spark.Utils
+import spark.storage.StorageLevel
+
+import org.apache.flume.source.avro.AvroSourceProtocol
+import org.apache.flume.source.avro.AvroFlumeEvent
+import org.apache.flume.source.avro.Status
+import org.apache.avro.ipc.specific.SpecificResponder
+import org.apache.avro.ipc.NettyServer
+
+import scala.collection.JavaConversions._
+
+import java.net.InetSocketAddress
+import java.io.{ObjectInput, ObjectOutput, Externalizable}
+import java.nio.ByteBuffer
+
+/**
+ * Network input stream of [[SparkFlumeEvent]]s. Its receiver is a [[FlumeReceiver]] built
+ * from the given host, port and storage level.
+ */
+private[streaming]
+class FlumeInputDStream[T: ClassManifest](
+    @transient ssc_ : StreamingContext,
+    host: String,
+    port: Int,
+    storageLevel: StorageLevel
+  ) extends NetworkInputDStream[SparkFlumeEvent](ssc_) {
+
+  override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = {
+    val flumeReceiver = new FlumeReceiver(host, port, storageLevel)
+    flumeReceiver
+  }
+}
+
+/**
+ * A wrapper class for AvroFlumeEvent's with a custom serialization format.
+ *
+ * This is necessary because AvroFlumeEvent uses inner data structures
+ * which are not serializable.
+ *
+ * Format produced by `writeExternal` and consumed by `readExternal`: body length (Int), body
+ * bytes, number of headers (Int), then for each header a length-prefixed serialized key
+ * followed by a length-prefixed serialized value.
+ */
+class SparkFlumeEvent() extends Externalizable {
+  var event : AvroFlumeEvent = new AvroFlumeEvent()
+
+  /* De-serialize from bytes. */
+  def readExternal(in: ObjectInput) {
+    // readFully is used for every byte array: a plain read() is allowed to return after
+    // filling only part of the array, which would leave the rest of the stream misaligned.
+    val bodyLength = in.readInt()
+    val bodyBuff = new Array[Byte](bodyLength)
+    in.readFully(bodyBuff)
+
+    val numHeaders = in.readInt()
+    val headers = new java.util.HashMap[CharSequence, CharSequence]
+
+    for (i <- 0 until numHeaders) {
+      val keyLength = in.readInt()
+      val keyBuff = new Array[Byte](keyLength)
+      in.readFully(keyBuff)
+      val key : String = Utils.deserialize(keyBuff)
+
+      val valLength = in.readInt()
+      val valBuff = new Array[Byte](valLength)
+      in.readFully(valBuff)
+      val value : String = Utils.deserialize(valBuff)
+
+      headers.put(key, value)
+    }
+
+    event.setBody(ByteBuffer.wrap(bodyBuff))
+    event.setHeaders(headers)
+  }
+
+  /* Serialize to bytes. */
+  def writeExternal(out: ObjectOutput) {
+    val body = event.getBody.array()
+    out.writeInt(body.length)
+    out.write(body)
+
+    val numHeaders = event.getHeaders.size()
+    out.writeInt(numHeaders)
+    for ((k, v) <- event.getHeaders) {
+      val keyBuff = Utils.serialize(k.toString)
+      out.writeInt(keyBuff.length)
+      out.write(keyBuff)
+      val valBuff = Utils.serialize(v.toString)
+      out.writeInt(valBuff.length)
+      out.write(valBuff)
+    }
+  }
+}
+
+private[streaming] object SparkFlumeEvent {
+  /** Builds a SparkFlumeEvent whose `event` field refers to the given Avro event. */
+  def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = {
+    val wrapper = new SparkFlumeEvent
+    wrapper.event = in
+    wrapper
+  }
+}
+
+/**
+ * A simple server that implements Flume's Avro protocol. Every event it is handed is
+ * wrapped in a SparkFlumeEvent and appended to the receiver's block generator.
+ */
+private[streaming]
+class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol {
+  override def append(event : AvroFlumeEvent) : Status = {
+    val wrapped = SparkFlumeEvent.fromAvroFlumeEvent(event)
+    receiver.blockGenerator += wrapped
+    Status.OK
+  }
+
+  override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = {
+    for (event <- events) {
+      receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)
+    }
+    Status.OK
+  }
+}
+
+/**
+ * A NetworkReceiver which listens for events using the
+ * Flume Avro interface.
+ *
+ * @param host hostname the Avro server is bound to; also returned as the location preference
+ * @param port port the Avro server is bound to
+ * @param storageLevel storage level handed to the block generator
+ */
+private[streaming]
+class FlumeReceiver(
+    host: String,
+    port: Int,
+    storageLevel: StorageLevel
+  ) extends NetworkReceiver[SparkFlumeEvent] {
+
+  lazy val blockGenerator = new BlockGenerator(storageLevel)
+
+  // Avro server created by onStart(). It is kept so that onStop() can close it and release
+  // the listening port. Transient because it only exists where the receiver is running.
+  @transient @volatile private var server: NettyServer = null
+
+  protected override def onStart() {
+    val responder = new SpecificResponder(
+      classOf[AvroSourceProtocol], new FlumeEventServer(this))
+    val nettyServer = new NettyServer(responder, new InetSocketAddress(host, port))
+    server = nettyServer
+    blockGenerator.start()
+    nettyServer.start()
+    logInfo("Flume receiver started")
+  }
+
+  protected override def onStop() {
+    // Stop accepting events before stopping the generator they are fed into.
+    if (server != null) {
+      server.close()
+      server = null
+    }
+    blockGenerator.stop()
+    logInfo("Flume receiver stopped")
+  }
+
+  override def getLocationPreference = Some(host)
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala
new file mode 100644
index 0000000000..ee69ea5177
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ForEachDStream.scala
@@ -0,0 +1,28 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.streaming.{Duration, DStream, Job, Time}
+
+/**
+ * DStream that generates no RDDs of its own; for every batch in which `parent` yields an
+ * RDD it creates a Job that applies `foreachFunc` to that RDD and the batch time.
+ */
+private[streaming]
+class ForEachDStream[T: ClassManifest] (
+    parent: DStream[T],
+    foreachFunc: (RDD[T], Time) => Unit
+  ) extends DStream[Unit](parent.ssc) {
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies = List(parent)
+
+  // This stream never produces data itself.
+  override def compute(validTime: Time): Option[RDD[Unit]] = None
+
+  override def generateJob(time: Time): Option[Job] = {
+    parent.getOrCompute(time).map { rdd =>
+      val jobFunc = () => {
+        foreachFunc(rdd, time)
+      }
+      new Job(time, jobFunc)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala
new file mode 100644
index 0000000000..b589cbd4d5
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/GlommedDStream.scala
@@ -0,0 +1,17 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+
+/**
+ * DStream produced by calling `glom()` on every RDD generated by `parent`.
+ */
+private[streaming]
+class GlommedDStream[T: ClassManifest](parent: DStream[T])
+  extends DStream[Array[T]](parent.ssc) {
+
+  // Batches are generated at the parent's rate.
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies = List(parent)
+
+  override def compute(validTime: Time): Option[RDD[Array[T]]] = {
+    for (parentRdd <- parent.getOrCompute(validTime))
+      yield parentRdd.glom()
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
new file mode 100644
index 0000000000..980ca5177e
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala
@@ -0,0 +1,19 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, StreamingContext, DStream}
+
+/**
+ * Base class for DStreams that have no parent streams. The slide duration is the batch
+ * duration of the streaming context's graph; subclasses supply start() and stop().
+ */
+abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
+  extends DStream[T](ssc_) {
+
+  override def dependencies = List()
+
+  override def slideDuration: Duration = {
+    // Fail with an explicit message instead of a bare NullPointerException.
+    if (ssc == null) {
+      throw new Exception("ssc is null")
+    }
+    val batchDuration = ssc.graph.batchDuration
+    if (batchDuration == null) {
+      throw new Exception("batchDuration is null")
+    }
+    batchDuration
+  }
+
+  def start()
+
+  def stop()
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
new file mode 100644
index 0000000000..2b4740bdf7
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala
@@ -0,0 +1,200 @@
+package spark.streaming.dstream
+
+import spark.Logging
+import spark.storage.StorageLevel
+import spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
+
+import java.util.Properties
+import java.util.concurrent.Executors
+
+import kafka.consumer._
+import kafka.message.{Message, MessageSet, MessageAndMetadata}
+import kafka.serializer.StringDecoder
+import kafka.utils.{Utils, ZKGroupTopicDirs}
+import kafka.utils.ZkUtils._
+
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+
+
+// Key for a specific Kafka partition: (broker, topic, group, part)
+case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
+// NOT USED - Originally intended for fault-tolerance
+// Metadata for a Kafka stream that is sent to the Master
+private[streaming]
+case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long])
+// NOT USED - Originally intended for fault-tolerance
+// Checkpoint data specific to a KafkaInputDStream
+private[streaming]
+case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any],
+ savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds)
+
+/**
+ * Input stream that pulls messages from a Kafka Broker.
+ *
+ * @param host ZooKeeper hostname.
+ * @param port ZooKeeper port.
+ * @param groupId The group id for this consumer.
+ * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
+ * in its own thread.
+ * @param initialOffsets Optional initial offsets for each of the partitions to consume.
+ * By default the value is pulled from ZooKeeper.
+ * @param storageLevel RDD storage level.
+ */
+private[streaming]
+class KafkaInputDStream[T: ClassManifest](
+ @transient ssc_ : StreamingContext,
+ host: String,
+ port: Int,
+ groupId: String,
+ topics: Map[String, Int],
+ initialOffsets: Map[KafkaPartitionKey, Long],
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+
+ // Metadata that keeps track of which messages have already been consumed.
+ var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]()
+
+ /* NOT USED - Originally intended for fault-tolerance
+
+ // In case of a failure, the offsets for a particular timestamp will be restored.
+ @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null
+
+
+ override protected[streaming] def addMetadata(metadata: Any) {
+ metadata match {
+ case x : KafkaInputDStreamMetadata =>
+ savedOffsets(x.timestamp) = x.data
+ // TODO: Remove logging
+ logInfo("New saved Offsets: " + savedOffsets)
+ case _ => logInfo("Received unknown metadata: " + metadata.toString)
+ }
+ }
+
+ override protected[streaming] def updateCheckpointData(currentTime: Time) {
+ super.updateCheckpointData(currentTime)
+ if(savedOffsets.size > 0) {
+ // Find the offsets that were stored before the checkpoint was initiated
+ val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last
+ val latestOffsets = savedOffsets(key)
+ logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString)
+ checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets)
+ // TODO: This may throw out offsets that are created after the checkpoint,
+ // but it's unlikely we'll need them.
+ savedOffsets.clear()
+ }
+ }
+
+ override protected[streaming] def restoreCheckpointData() {
+ super.restoreCheckpointData()
+ logInfo("Restoring KafkaDStream checkpoint data.")
+ checkpointData match {
+ case x : KafkaDStreamCheckpointData =>
+ restoredOffsets = x.savedOffsets
+ logInfo("Restored KafkaDStream offsets: " + savedOffsets)
+ }
+ } */
+
+ // The receiver is declared as NetworkReceiver[Any]; the cast adapts it to this stream's T.
+ def createReceiver(): NetworkReceiver[T] = {
+ new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+/**
+ * Receiver that connects to Kafka through ZooKeeper and feeds every consumed message into a
+ * BlockGenerator. One MessageHandler is submitted to a thread pool for each message stream
+ * returned by the consumer connector.
+ */
+private[streaming]
+class KafkaReceiver(host: String, port: Int, groupId: String,
+  topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
+  storageLevel: StorageLevel) extends NetworkReceiver[Any] {
+
+  // Timeout for establishing a connection to ZooKeeper in ms.
+  val ZK_TIMEOUT = 10000
+
+  // Handles pushing data into the BlockManager
+  lazy protected val blockGenerator = new BlockGenerator(storageLevel)
+  // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset
+  lazy val offsets = HashMap[KafkaPartitionKey, Long]()
+  // Connection to Kafka
+  var consumerConnector : ZookeeperConsumerConnector = null
+  // Pool running the MessageHandlers; created by onStart() and released by onStop().
+  @transient private var executorPool: java.util.concurrent.ExecutorService = null
+
+  def onStop() {
+    // Release the Kafka/ZooKeeper connection and the handler threads, which were
+    // previously left running after the receiver was stopped.
+    if (consumerConnector != null) {
+      consumerConnector.shutdown()
+      consumerConnector = null
+    }
+    if (executorPool != null) {
+      executorPool.shutdown()
+      executorPool = null
+    }
+    blockGenerator.stop()
+  }
+
+  def onStart() {
+
+    blockGenerator.start()
+
+    // In case we are using multiple Threads to handle Kafka Messages
+    val pool = Executors.newFixedThreadPool(topics.values.sum)
+    executorPool = pool
+
+    val zooKeeperEndPoint = host + ":" + port
+    logInfo("Starting Kafka Consumer Stream with group: " + groupId)
+    logInfo("Initial offsets: " + initialOffsets.toString)
+
+    // ZooKeeper connection properties
+    val props = new Properties()
+    props.put("zk.connect", zooKeeperEndPoint)
+    props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
+    props.put("groupid", groupId)
+
+    // Create the connection to the cluster
+    logInfo("Connecting to Zookeper: " + zooKeeperEndPoint)
+    val consumerConfig = new ConsumerConfig(props)
+    consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
+    logInfo("Connected to " + zooKeeperEndPoint)
+
+    // Reset the Kafka offsets in case we are recovering from a failure
+    resetOffsets(initialOffsets)
+
+    // Create Threads for each Topic/Message Stream we are listening
+    val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
+
+    // Start the messages handler for each partition
+    topicMessageStreams.values.foreach { streams =>
+      streams.foreach { stream => pool.submit(new MessageHandler(stream)) }
+    }
+
+  }
+
+  // Overwrites the offsets in ZooKeeper.
+  private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) {
+    offsets.foreach { case(key, offset) =>
+      val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
+      val partitionName = key.brokerId + "-" + key.partId
+      updatePersistentPath(consumerConnector.zkClient,
+        topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
+    }
+  }
+
+  // Handles Kafka Messages
+  private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
+    def run() {
+      logInfo("Starting MessageHandler.")
+      // foreach rather than takeWhile { ...; true }: takeWhile builds a result collection
+      // holding every message consumed, which grows for as long as the stream lives.
+      stream.foreach { msgAndMetadata =>
+        blockGenerator += msgAndMetadata.message
+
+        // Updating the offset. The key is (broker, topic, group, partition).
+        val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic,
+          groupId, msgAndMetadata.topicInfo.partition.partId)
+        val offset = msgAndMetadata.topicInfo.getConsumeOffset
+        // Several handlers run concurrently and share this mutable map.
+        offsets.synchronized {
+          offsets.put(key, offset)
+        }
+        // logInfo("Handled message: " + (key, offset).toString)
+      }
+    }
+  }
+
+  // NOT USED - Originally intended for fault-tolerance
+  // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel)
+  // extends BufferingBlockCreator[Any](receiver, storageLevel) {
+
+  // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = {
+  // // Creates a new Block with Kafka-specific Metadata
+  // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap))
+  // }
+
+  // }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala
new file mode 100644
index 0000000000..848afecfad
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+
+/**
+ * DStream produced by calling `mapPartitions(mapPartFunc, preservePartitioning)` on every
+ * RDD generated by `parent`.
+ */
+private[streaming]
+class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
+    parent: DStream[T],
+    mapPartFunc: Iterator[T] => Iterator[U],
+    preservePartitioning: Boolean
+  ) extends DStream[U](parent.ssc) {
+
+  // Batches are generated at the parent's rate.
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies = List(parent)
+
+  override def compute(validTime: Time): Option[RDD[U]] = {
+    for (parentRdd <- parent.getOrCompute(validTime))
+      yield parentRdd.mapPartitions[U](mapPartFunc, preservePartitioning)
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala
new file mode 100644
index 0000000000..6055aa6a05
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MapValuedDStream.scala
@@ -0,0 +1,21 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+import spark.SparkContext._
+
+/**
+ * DStream of pairs produced by calling `mapValues(mapValueFunc)` on every RDD generated
+ * by `parent`.
+ */
+private[streaming]
+class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+    parent: DStream[(K, V)],
+    mapValueFunc: V => U
+  ) extends DStream[(K, U)](parent.ssc) {
+
+  // Batches are generated at the parent's rate.
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies = List(parent)
+
+  override def compute(validTime: Time): Option[RDD[(K, U)]] = {
+    for (parentRdd <- parent.getOrCompute(validTime))
+      yield parentRdd.mapValues[U](mapValueFunc)
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala
new file mode 100644
index 0000000000..20818a0cab
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/MappedDStream.scala
@@ -0,0 +1,20 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+
+/**
+ * DStream produced by calling `map(mapFunc)` on every RDD generated by `parent`.
+ */
+private[streaming]
+class MappedDStream[T: ClassManifest, U: ClassManifest] (
+    parent: DStream[T],
+    mapFunc: T => U
+  ) extends DStream[U](parent.ssc) {
+
+  // Batches are generated at the parent's rate.
+  override def slideDuration: Duration = parent.slideDuration
+
+  override def dependencies = List(parent)
+
+  override def compute(validTime: Time): Option[RDD[U]] = {
+    for (parentRdd <- parent.getOrCompute(validTime))
+      yield parentRdd.map[U](mapFunc)
+  }
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
new file mode 100644
index 0000000000..aa6be95f30
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala
@@ -0,0 +1,254 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Time, StreamingContext, AddBlocks, RegisterReceiver, DeregisterReceiver}
+
+import spark.{Logging, SparkEnv, RDD}
+import spark.rdd.BlockRDD
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Props, Actor}
+import akka.pattern.ask
+import akka.dispatch.Await
+import akka.util.duration._
+import spark.streaming.util.{RecurringTimer, SystemClock}
+import java.util.concurrent.ArrayBlockingQueue
+
+/**
+ * Abstract class for defining any InputDStream that has to start a receiver on worker
+ * nodes to receive external data. Specific implementations of NetworkInputDStream must
+ * define the createReceiver() function that creates the receiver object of type
+ * [[spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive
+ * data.
+ * @param ssc_ Streaming context that will execute this input stream
+ * @tparam T Class type of the object of this stream
+ */
+abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
+ extends InputDStream[T](ssc_) {
+
+ // This is a unique identifier that is used to match the network receiver with the
+ // corresponding network input stream.
+ val id = ssc.getNewNetworkStreamId()
+
+ /**
+ * Creates the receiver object that will be sent to the worker nodes
+ * to receive data. This method needs to be defined by any specific implementation
+ * of a NetworkInputDStream.
+ */
+ def createReceiver(): NetworkReceiver[T]
+
+ // Nothing to start or stop as both are taken care of by the NetworkInputTracker.
+ def start() {}
+
+ def stop() {}
+
+ // Builds the batch's RDD from the block ids the tracker holds for this stream and time.
+ override def compute(validTime: Time): Option[RDD[T]] = {
+ val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
+ Some(new BlockRDD[T](ssc.sc, blockIds))
+ }
+}
+
+
+// Messages handled by the NetworkReceiverActor of a NetworkReceiver.
+private[streaming] sealed trait NetworkReceiverMessage
+// Stops the receiver and deregisters it from the tracker with the given message.
+private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
+// Forwarded to the tracker as an AddBlocks message for the given block id and metadata.
+private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
+// Deregisters the receiver from the tracker with the given error message.
+private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
+
+/**
+ * Abstract class of a receiver that can be run on worker nodes to receive external data. See
+ * [[spark.streaming.dstream.NetworkInputDStream]] for an explanation.
+ */
+abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Logging {
+
+  initLogging()
+
+  lazy protected val env = SparkEnv.get
+
+  lazy protected val actor = env.actorSystem.actorOf(
+    Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
+
+  // Bound to whichever thread first forces this lazy val; start() forces it so that stop()
+  // can interrupt the thread that is running the receiver.
+  lazy protected val receivingThread = Thread.currentThread()
+
+  protected var streamId: Int = -1
+
+  /**
+   * This method will be called to start receiving data. All your receiver
+   * starting code should be implemented by defining this function.
+   */
+  protected def onStart()
+
+  /** This method will be called to stop receiving data. */
+  protected def onStop()
+
+  /** Conveys a placement preference (hostname) for this receiver. */
+  def getLocationPreference() : Option[String] = None
+
+  /**
+   * Starts the receiver. First it accesses all the lazy members to
+   * materialize them. Then it calls the user-defined onStart() method to start
+   * other threads, etc. required to receive the data.
+   */
+  def start() {
+    try {
+      // Access the lazy vals to materialize them
+      env
+      actor
+      receivingThread
+
+      // Call user-defined onStart()
+      onStart()
+    } catch {
+      case ie: InterruptedException =>
+        logInfo("Receiving thread interrupted")
+        //println("Receiving thread interrupted")
+      case e: Exception =>
+        stopOnError(e)
+    }
+  }
+
+  /**
+   * Stops the receiver. First it interrupts the main receiving thread,
+   * that is, the thread that called receiver.start(). Then it calls the user-defined
+   * onStop() method to stop other threads and/or do cleanup.
+   */
+  def stop() {
+    receivingThread.interrupt()
+    onStop()
+    //TODO: terminate the actor
+  }
+
+  /**
+   * Stops the receiver and reports the exception to the tracker.
+   * This should be called whenever an exception has happened on any thread
+   * of the receiver.
+   */
+  protected def stopOnError(e: Exception) {
+    logError("Error receiving data", e)
+    stop()
+    actor ! ReportError(e.toString)
+  }
+
+
+  /**
+   * Pushes a block (as iterator of values) into the block manager.
+   */
+  def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) {
+    val buffer = new ArrayBuffer[T] ++ iterator
+    env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level)
+
+    actor ! ReportBlock(blockId, metadata)
+  }
+
+  /**
+   * Pushes a block (as bytes) into the block manager.
+   */
+  def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
+    env.blockManager.putBytes(blockId, bytes, level)
+    actor ! ReportBlock(blockId, metadata)
+  }
+
+  /** A helper actor that communicates with the NetworkInputTracker */
+  private class NetworkReceiverActor extends Actor {
+    logInfo("Attempting to register with tracker")
+    val ip = System.getProperty("spark.master.host", "localhost")
+    val port = System.getProperty("spark.master.port", "7077").toInt
+    val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+    val tracker = env.actorSystem.actorFor(url)
+    val timeout = 5.seconds
+
+    // Registers with the tracker and waits (up to `timeout`) for its reply.
+    override def preStart() {
+      val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
+      Await.result(future, timeout)
+    }
+
+    override def receive() = {
+      case ReportBlock(blockId, metadata) =>
+        tracker ! AddBlocks(streamId, Array(blockId), metadata)
+      case ReportError(msg) =>
+        tracker ! DeregisterReceiver(streamId, msg)
+      case StopReceiver(msg) =>
+        stop()
+        tracker ! DeregisterReceiver(streamId, msg)
+    }
+  }
+
+  protected[streaming] def setStreamId(id: Int) {
+    streamId = id
+  }
+
+  /**
+   * Batches objects created by a [[spark.streaming.dstream.NetworkReceiver]] and puts them
+   * into appropriately named blocks at regular intervals. This class starts two threads,
+   * one to periodically start a new batch and prepare the previous batch as a block,
+   * the other to push the blocks into the block manager.
+   *
+   * Appending with `+=` and the periodic buffer swap are guarded by this object's lock, so
+   * `+=` may be called from any thread.
+   */
+  class BlockGenerator(storageLevel: StorageLevel)
+    extends Serializable with Logging {
+
+    case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
+
+    val clock = new SystemClock()
+    val blockInterval = 200L
+    val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+    val blockStorageLevel = storageLevel
+    val blocksForPushing = new ArrayBlockingQueue[Block](1000)
+    val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
+
+    var currentBuffer = new ArrayBuffer[T]
+
+    def start() {
+      blockIntervalTimer.start()
+      blockPushingThread.start()
+      logInfo("Data handler started")
+    }
+
+    def stop() {
+      blockIntervalTimer.stop()
+      blockPushingThread.interrupt()
+      logInfo("Data handler stopped")
+    }
+
+    def += (obj: T) {
+      // ArrayBuffer is not thread-safe and the timer thread replaces currentBuffer
+      // concurrently; take the same lock as the swap in updateCurrentBuffer().
+      synchronized {
+        currentBuffer += obj
+      }
+    }
+
+    private def createBlock(blockId: String, iterator: Iterator[T]) : Block = {
+      new Block(blockId, iterator)
+    }
+
+    private def updateCurrentBuffer(time: Long) {
+      try {
+        // Swap under the lock so no append can land in a buffer that was already handed off.
+        val newBlockBuffer = synchronized {
+          val filledBuffer = currentBuffer
+          currentBuffer = new ArrayBuffer[T]
+          filledBuffer
+        }
+        if (newBlockBuffer.size > 0) {
+          val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval)
+          val newBlock = createBlock(blockId, newBlockBuffer.toIterator)
+          blocksForPushing.add(newBlock)
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logInfo("Block interval timer thread interrupted")
+        case e: Exception =>
+          // Record why the receiver is being stopped instead of failing silently.
+          logError("Error while preparing a block for pushing", e)
+          NetworkReceiver.this.stop()
+      }
+    }
+
+    private def keepPushingBlocks() {
+      logInfo("Block pushing thread started")
+      try {
+        while(true) {
+          val block = blocksForPushing.take()
+          NetworkReceiver.this.pushBlock(block.id, block.iterator, block.metadata, storageLevel)
+        }
+      } catch {
+        case ie: InterruptedException =>
+          logInfo("Block pushing thread interrupted")
+        case e: Exception =>
+          // Record why the receiver is being stopped instead of failing silently.
+          logError("Error while pushing a block", e)
+          NetworkReceiver.this.stop()
+      }
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
new file mode 100644
index 0000000000..024bf3bea4
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala
@@ -0,0 +1,41 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.rdd.UnionRDD
+
+import scala.collection.mutable.Queue
+import scala.collection.mutable.ArrayBuffer
+import spark.streaming.{Time, StreamingContext}
+
+/**
+ * Input stream that takes the RDDs for each batch from `queue`.
+ *
+ * @param ssc streaming context this stream belongs to
+ * @param queue queue of RDDs to consume
+ * @param oneAtATime if true, at most one RDD is dequeued per batch; otherwise every RDD
+ *                   currently in the queue is dequeued and their union forms the batch
+ * @param defaultRDD RDD returned for a batch when nothing was taken from the queue; may be
+ *                   null, in which case such a batch has no RDD
+ */
+class QueueInputDStream[T: ClassManifest](
+    @transient ssc: StreamingContext,
+    val queue: Queue[RDD[T]],
+    oneAtATime: Boolean,
+    defaultRDD: RDD[T]
+  ) extends InputDStream[T](ssc) {
+
+  override def start() { }
+
+  override def stop() { }
+
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    val buffer = new ArrayBuffer[RDD[T]]()
+    if (oneAtATime && queue.size > 0) {
+      buffer += queue.dequeue()
+    } else {
+      // Remove what is consumed; merely copying the queue would replay the same RDDs in
+      // every subsequent batch.
+      buffer ++= queue.dequeueAll(_ => true)
+    }
+    if (buffer.size > 0) {
+      if (oneAtATime) {
+        Some(buffer.head)
+      } else {
+        Some(new UnionRDD(ssc.sc, buffer.toSeq))
+      }
+    } else if (defaultRDD != null) {
+      Some(defaultRDD)
+    } else {
+      None
+    }
+  }
+
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
new file mode 100644
index 0000000000..04e6b69b7b
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala
@@ -0,0 +1,91 @@
+package spark.streaming.dstream
+
+import spark.Logging
+import spark.storage.StorageLevel
+import spark.streaming.StreamingContext
+
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.nio.channels.{ReadableByteChannel, SocketChannel}
+import java.io.EOFException
+import java.util.concurrent.ArrayBlockingQueue
+
+
+/**
+ * An input stream that reads blocks of serialized objects from a given network address.
+ * The blocks are inserted directly into the block store. This is the fastest way to get
+ * data into Spark Streaming, though it requires the sender to batch data and serialize it
+ * in the format that the system is configured with.
+ */
+private[streaming]
+class RawInputDStream[T: ClassManifest](
+    @transient ssc_ : StreamingContext,
+    host: String,
+    port: Int,
+    storageLevel: StorageLevel
+  ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+
+  // The receiver is declared as NetworkReceiver[Any]; the cast adapts it to this stream's T.
+  def createReceiver(): NetworkReceiver[T] = {
+    val rawReceiver = new RawNetworkReceiver(host, port, storageLevel)
+    rawReceiver.asInstanceOf[NetworkReceiver[T]]
+  }
+}
+
+/**
+ * Receiver that connects to host:port and reads length-prefixed blocks: a 4-byte length
+ * followed by that many bytes. Each block is queued and pushed to the block manager as-is
+ * by a separate daemon thread.
+ */
+private[streaming]
+class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
+  extends NetworkReceiver[Any] {
+
+  var blockPushingThread: Thread = null
+
+  override def getLocationPreference = None
+
+  def onStart() {
+    // Open a socket to the target address and keep reading from it
+    logInfo("Connecting to " + host + ":" + port)
+    val channel = SocketChannel.open()
+    // The read loop only ends by throwing (EOF, I/O error, interruption); make sure the
+    // socket is released whichever way it ends.
+    try {
+      channel.configureBlocking(true)
+      channel.connect(new InetSocketAddress(host, port))
+      logInfo("Connected to " + host + ":" + port)
+
+      // Small hand-off queue between the reading loop and the pushing thread.
+      val queue = new ArrayBlockingQueue[ByteBuffer](2)
+
+      blockPushingThread = new Thread {
+        setDaemon(true)
+        override def run() {
+          var nextBlockNumber = 0
+          while (true) {
+            val buffer = queue.take()
+            val blockId = "input-" + streamId + "-" + nextBlockNumber
+            nextBlockNumber += 1
+            pushBlock(blockId, buffer, null, storageLevel)
+          }
+        }
+      }
+      blockPushingThread.start()
+
+      val lengthBuffer = ByteBuffer.allocate(4)
+      while (true) {
+        lengthBuffer.clear()
+        readFully(channel, lengthBuffer)
+        lengthBuffer.flip()
+        val length = lengthBuffer.getInt()
+        val dataBuffer = ByteBuffer.allocate(length)
+        readFully(channel, dataBuffer)
+        dataBuffer.flip()
+        logInfo("Read a block with " + length + " bytes")
+        queue.put(dataBuffer)
+      }
+    } finally {
+      channel.close()
+    }
+  }
+
+  def onStop() {
+    if (blockPushingThread != null) blockPushingThread.interrupt()
+  }
+
+  /** Read a buffer fully from a given Channel */
+  private def readFully(channel: ReadableByteChannel, dest: ByteBuffer) {
+    while (dest.position < dest.limit) {
+      if (channel.read(dest) == -1) {
+        throw new EOFException("End of channel")
+      }
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
new file mode 100644
index 0000000000..733d5c4a25
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -0,0 +1,149 @@
+package spark.streaming.dstream
+
+import spark.streaming.StreamingContext._
+
+import spark.RDD
+import spark.rdd.CoGroupedRDD
+import spark.Partitioner
+import spark.SparkContext._
+import spark.storage.StorageLevel
+
+import scala.collection.mutable.ArrayBuffer
+import spark.streaming.{Duration, Interval, Time, DStream}
+
+/**
+ * DStream that maintains a per-key reduction over a sliding window incrementally: each
+ * window's result is derived from the previous window's result by "inverse reducing" the
+ * batches that left the window and reducing in the batches that entered it.
+ *
+ * @param parent          stream of key-value pairs to be windowed
+ * @param reduceFunc      function used to combine the values of a key
+ * @param invReduceFunc   function used to remove the combined value of departed batches,
+ *                        applied as invReduceFunc(previousWindowValue, combinedOldValue)
+ * @param _windowDuration width of the window; must be a multiple of parent.slideDuration
+ * @param _slideDuration  interval between windows; must be a multiple of parent.slideDuration
+ * @param partitioner     partitioner used for the per-batch reduceByKey and for the cogroup
+ */
+private[streaming]
+class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
+    parent: DStream[(K, V)],
+    reduceFunc: (V, V) => V,
+    invReduceFunc: (V, V) => V,
+    _windowDuration: Duration,
+    _slideDuration: Duration,
+    partitioner: Partitioner
+  ) extends DStream[(K,V)](parent.ssc) {
+
+  // Report the window duration here (the message previously printed the slide duration)
+  assert(_windowDuration.isMultipleOf(parent.slideDuration),
+    "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " +
+    "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")"
+  )
+
+  assert(_slideDuration.isMultipleOf(parent.slideDuration),
+    "The slide duration of ReducedWindowedDStream (" + _slideDuration + ") " +
+    "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")"
+  )
+
+  // Reduce each batch of data using reduceByKey which will be further reduced by window
+  // by ReducedWindowedDStream
+  val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+
+  // Persist RDDs to memory by default as these RDDs are going to be reused.
+  super.persist(StorageLevel.MEMORY_ONLY_SER)
+  reducedStream.persist(StorageLevel.MEMORY_ONLY_SER)
+
+  def windowDuration: Duration = _windowDuration
+
+  override def dependencies = List(reducedStream)
+
+  override def slideDuration: Duration = _slideDuration
+
+  override val mustCheckpoint = true
+
+  override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+  /** Applies the storage level to this stream and to the per-batch reduced stream. */
+  override def persist(storageLevel: StorageLevel): DStream[(K,V)] = {
+    super.persist(storageLevel)
+    reducedStream.persist(storageLevel)
+    this
+  }
+
+  /** Sets the checkpoint interval of this stream only; reducedStream is left untouched. */
+  override def checkpoint(interval: Duration): DStream[(K, V)] = {
+    super.checkpoint(interval)
+    //reducedStream.checkpoint(interval)
+    this
+  }
+
+  /**
+   * Builds the RDD of reduced values for the window ending at `validTime` by cogrouping the
+   * previous window's RDD with the per-batch RDDs that left and entered the window, then
+   * merging the values of each key. Always returns Some.
+   */
+  override def compute(validTime: Time): Option[RDD[(K, V)]] = {
+    // Local copies of the functions, referenced from the mergeValues closure below
+    val reduceF = reduceFunc
+    val invReduceF = invReduceFunc
+
+    val currentTime = validTime
+    val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration, currentTime)
+    val previousWindow = currentWindow - slideDuration
+
+    logDebug("Window time = " + windowDuration)
+    logDebug("Slide time = " + slideDuration)
+    logDebug("ZeroTime = " + zeroTime)
+    logDebug("Current window = " + currentWindow)
+    logDebug("Previous window = " + previousWindow)
+
+    //  _____________________________
+    // |  previous window   _________|___________________
+    // |___________________|       current window        |  --------------> Time
+    //                     |_____________________________|
+    //
+    // |________ _________|          |________ _________|
+    //          |                             |
+    //          V                             V
+    //       old RDDs                     new RDDs
+    //
+
+    // Get the RDDs of the reduced values in "old time steps"
+    val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
+    logDebug("# old RDDs = " + oldRDDs.size)
+
+    // Get the RDDs of the reduced values in "new time steps"
+    val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime)
+    logDebug("# new RDDs = " + newRDDs.size)
+
+    // Get the RDD of the reduced value of the previous window (empty RDD if there is none)
+    val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]()))
+
+    // Make the list of RDDs that needs to cogrouped together for reducing their reduced values.
+    // Order matters: index 0 is the previous window, then the old RDDs, then the new RDDs.
+    val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs
+
+    // Cogroup the reduced RDDs and merge the reduced values
+    val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner)
+    //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _
+
+    val numOldValues = oldRDDs.size
+    val numNewValues = newRDDs.size
+
+    // Merges, for one key, the sequences produced by the cogroup (one sequence per input RDD)
+    val mergeValues = (seqOfValues: Seq[Seq[V]]) => {
+      if (seqOfValues.size != 1 + numOldValues + numNewValues) {
+        throw new Exception("Unexpected number of sequences of reduced values")
+      }
+      // Getting reduced values "old time steps" that will be removed from current window
+      val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head)
+      // Getting reduced values "new time steps"
+      val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head)
+      if (seqOfValues(0).isEmpty) {
+        // If previous window's reduce value does not exist, then at least new values should exist
+        if (newValues.isEmpty) {
+          throw new Exception("Neither previous window has value for key, nor new values found. " +
+            "Are you sure your key class hashes consistently?")
+        }
+        // Reduce the new values
+        newValues.reduce(reduceF) // return
+      } else {
+        // Get the previous window's reduced value
+        var tempValue = seqOfValues(0).head
+        // If old values exist, then inverse reduce them from the previous value
+        if (!oldValues.isEmpty) {
+          tempValue = invReduceF(tempValue, oldValues.reduce(reduceF))
+        }
+        // If new values exist, then reduce them with the previous value
+        if (!newValues.isEmpty) {
+          tempValue = reduceF(tempValue, newValues.reduce(reduceF))
+        }
+        tempValue // return
+      }
+    }
+
+    val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues)
+
+    Some(mergedValuesRDD)
+  }
+
+
+}
+
+
diff --git a/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala
new file mode 100644
index 0000000000..1f9548bfb8
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/ShuffledDStream.scala
@@ -0,0 +1,27 @@
+package spark.streaming.dstream
+
+import spark.{RDD, Partitioner}
+import spark.SparkContext._
+import spark.streaming.{Duration, DStream, Time}
+
+/**
+ * DStream whose RDD for each batch is the parent's RDD for that batch passed through
+ * `combineByKey` with the supplied combiner functions and partitioner.
+ */
+private[streaming]
+class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
+    parent: DStream[(K,V)],
+    createCombiner: V => C,
+    mergeValue: (C, V) => C,
+    mergeCombiner: (C, C) => C,
+    partitioner: Partitioner
+  ) extends DStream [(K,C)] (parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  /** Combines the parent's RDD for `validTime` by key; None when the parent has no RDD. */
+  override def compute(validTime: Time): Option[RDD[(K,C)]] = {
+    parent.getOrCompute(validTime).map { batch =>
+      batch.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner)
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
new file mode 100644
index 0000000000..d42027092b
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala
@@ -0,0 +1,103 @@
+package spark.streaming.dstream
+
+import spark.streaming.StreamingContext
+import spark.storage.StorageLevel
+
+import java.io._
+import java.net.Socket
+
+/**
+ * Network input stream whose receiver is a SocketReceiver built from the host, port,
+ * decoding function and storage level given here.
+ */
+private[streaming]
+class SocketInputDStream[T: ClassManifest](
+    @transient ssc_ : StreamingContext,
+    host: String,
+    port: Int,
+    bytesToObjects: InputStream => Iterator[T],
+    storageLevel: StorageLevel
+  ) extends NetworkInputDStream[T](ssc_) {
+
+  /** Creates a new SocketReceiver configured with this stream's parameters. */
+  def createReceiver(): NetworkReceiver[T] =
+    new SocketReceiver(host, port, bytesToObjects, storageLevel)
+}
+
+/**
+ * Receiver that opens a socket to `host:port`, turns its input stream into objects with
+ * `bytesToObjects` and feeds every object to a BlockGenerator created with `storageLevel`.
+ */
+private[streaming]
+class SocketReceiver[T: ClassManifest](
+    host: String,
+    port: Int,
+    bytesToObjects: InputStream => Iterator[T],
+    storageLevel: StorageLevel
+  ) extends NetworkReceiver[T] {
+
+  lazy protected val blockGenerator = new BlockGenerator(storageLevel)
+
+  override def getLocationPreference = None
+
+  /**
+   * Connects, starts the block generator and consumes the object iterator until it is
+   * exhausted. The socket is closed when the iterator ends or when reading fails, so the
+   * connection is not leaked.
+   */
+  protected def onStart() {
+    logInfo("Connecting to " + host + ":" + port)
+    val socket = new Socket(host, port)
+    try {
+      logInfo("Connected to " + host + ":" + port)
+      blockGenerator.start()
+      val iterator = bytesToObjects(socket.getInputStream())
+      while(iterator.hasNext) {
+        val obj = iterator.next
+        blockGenerator += obj
+      }
+    } finally {
+      socket.close()
+    }
+  }
+
+  /** Stops the block generator. */
+  protected def onStop() {
+    blockGenerator.stop()
+  }
+
+}
+
+private[streaming]
+object SocketReceiver {
+
+  /**
+   * Translates the data from an input stream (say, from a socket) into '\n' delimited
+   * strings, decoded as UTF-8, and returns an iterator over those strings.
+   *
+   * The underlying reader is closed as soon as the end of the stream is detected. Calling
+   * `next()` once the stream is exhausted throws NoSuchElementException, whether or not
+   * `hasNext` was called first.
+   */
+  def bytesToLines(inputStream: InputStream): Iterator[String] = {
+    val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))
+
+    val iterator = new Iterator[String] {
+      // True while `nextValue` holds a line that has been read but not yet handed out
+      var gotNext = false
+      // True once readLine() has returned null, i.e. the end of the stream was reached
+      var finished = false
+      var nextValue: String = null
+
+      // Reads one line ahead; on end of stream marks the iterator finished and closes the reader
+      private def getNext() {
+        nextValue = dataInputStream.readLine()
+        if (nextValue == null) {
+          finished = true
+          dataInputStream.close()
+        }
+        gotNext = true
+      }
+
+      override def hasNext: Boolean = {
+        if (!finished && !gotNext) {
+          getNext()
+        }
+        !finished
+      }
+
+      override def next(): String = {
+        // Going through hasNext guarantees that a line is buffered or that the end was
+        // detected, so a null is never returned in place of a line
+        if (!hasNext) {
+          throw new NoSuchElementException("End of stream")
+        }
+        gotNext = false
+        nextValue
+      }
+    }
+    iterator
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
new file mode 100644
index 0000000000..b4506c74aa
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala
@@ -0,0 +1,84 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.Partitioner
+import spark.SparkContext._
+import spark.storage.StorageLevel
+import spark.streaming.{Duration, Time, DStream}
+
+/**
+ * DStream of per-key state. The state RDD of a batch is produced by applying `updateFunc`
+ * to the parent's values for that batch together with the state of the previous batch.
+ */
+private[streaming]
+class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest](
+    parent: DStream[(K, V)],
+    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+    partitioner: Partitioner,
+    preservePartitioning: Boolean
+  ) extends DStream[(K, S)](parent.ssc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY_SER)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override val mustCheckpoint = true
+
+  /**
+   * Computes the state RDD for `validTime` from the previous state (one slide earlier) and
+   * the parent's RDD for `validTime`; the four combinations of presence are handled below.
+   */
+  override def compute(validTime: Time): Option[RDD[(K, S)]] = {
+    // The previous state is looked up first, then the parent batch
+    val previousState = getOrCompute(validTime - slideDuration)
+    val incomingBatch = parent.getOrCompute(validTime)
+
+    (previousState, incomingBatch) match {
+
+      // Both exist: cogroup new values with old state and run the update function
+      case (Some(prevStateRDD), Some(parentRDD)) =>
+        // Local reference to the update function for use inside the closure
+        val updateFuncLocal = updateFunc
+        // Map each cogrouped tuple to (key, new values, optional old state), then update
+        val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => {
+          val i = iterator.map(t => (t._1, t._2._1, t._2._2.headOption))
+          updateFuncLocal(i)
+        }
+        val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
+        Some(cogroupedRDD.mapPartitions(finalFunc, preservePartitioning))
+
+      // No new data: the old state RDD is carried forward unchanged
+      case (Some(prevStateRDD), None) =>
+        Some(prevStateRDD)
+
+      // No previous state (first input data): group the values and update with no old state
+      case (None, Some(parentRDD)) =>
+        val updateFuncLocal = updateFunc
+        val finalFunc = (iterator: Iterator[(K, Seq[V])]) => {
+          updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None)))
+        }
+        val groupedRDD = parentRDD.groupByKey(partitioner)
+        Some(groupedRDD.mapPartitions(finalFunc, preservePartitioning))
+
+      // Neither previous state nor new data: nothing to do
+      case (None, None) =>
+        None
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala
new file mode 100644
index 0000000000..99660d9dee
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/TransformedDStream.scala
@@ -0,0 +1,19 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.streaming.{Duration, DStream, Time}
+
+/**
+ * DStream whose RDD for each batch is `transformFunc` applied to the parent's RDD for that
+ * batch and the batch time.
+ */
+private[streaming]
+class TransformedDStream[T: ClassManifest, U: ClassManifest] (
+    parent: DStream[T],
+    transformFunc: (RDD[T], Time) => RDD[U]
+  ) extends DStream[U](parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  /** Transforms the parent's RDD for `validTime`; None when the parent has no RDD. */
+  override def compute(validTime: Time): Option[RDD[U]] = {
+    for (parentRDD <- parent.getOrCompute(validTime)) yield transformFunc(parentRDD, validTime)
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala
new file mode 100644
index 0000000000..00bad5da34
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/UnionDStream.scala
@@ -0,0 +1,40 @@
+package spark.streaming.dstream
+
+import spark.streaming.{Duration, DStream, Time}
+import spark.RDD
+import collection.mutable.ArrayBuffer
+import spark.rdd.UnionRDD
+
+/**
+ * DStream that unions the RDDs of several parent streams batch by batch.
+ *
+ * All parents must belong to the same StreamingContext and have the same slide duration;
+ * an empty `parents` array is rejected with IllegalArgumentException.
+ */
+private[streaming]
+class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
+  extends DStream[T](
+    // Validate before touching parents.head: the superclass argument is evaluated before
+    // the class body, so a check placed in the body could never fire for an empty array.
+    if (parents.isEmpty) throw new IllegalArgumentException("Empty array of parents")
+    else parents.head.ssc
+  ) {
+
+  if (parents.map(_.ssc).distinct.size > 1) {
+    throw new IllegalArgumentException("Array of parents have different StreamingContexts")
+  }
+
+  if (parents.map(_.slideDuration).distinct.size > 1) {
+    throw new IllegalArgumentException("Array of parents have different slide times")
+  }
+
+  override def dependencies = parents.toList
+
+  override def slideDuration: Duration = parents.head.slideDuration
+
+  /**
+   * Asks every parent for its RDD at `validTime` and unions them. Throws if any parent
+   * could not produce an RDD for that time.
+   */
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    val rdds = new ArrayBuffer[RDD[T]]()
+    // All parents are computed first; only then is a missing RDD reported
+    parents.map(_.getOrCompute(validTime)).foreach(_ match {
+      case Some(rdd) => rdds += rdd
+      case None => throw new Exception("Could not generate RDD from a parent for unifying at time " + validTime)
+    })
+    if (rdds.size > 0) {
+      Some(new UnionRDD(ssc.sc, rdds))
+    } else {
+      None
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala
new file mode 100644
index 0000000000..cbf0c88108
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/dstream/WindowedDStream.scala
@@ -0,0 +1,40 @@
+package spark.streaming.dstream
+
+import spark.RDD
+import spark.rdd.UnionRDD
+import spark.storage.StorageLevel
+import spark.streaming.{Duration, Interval, Time, DStream}
+
+/**
+ * DStream whose RDD for each batch is the union of the parent's RDDs that fall inside a
+ * sliding window ending at the batch time.
+ *
+ * @param parent          stream to be windowed; it is persisted with MEMORY_ONLY_SER here
+ * @param _windowDuration width of the window; must be a multiple of parent.slideDuration
+ * @param _slideDuration  interval between windows; must be a multiple of parent.slideDuration
+ */
+private[streaming]
+class WindowedDStream[T: ClassManifest](
+    parent: DStream[T],
+    _windowDuration: Duration,
+    _slideDuration: Duration)
+  extends DStream[T](parent.ssc) {
+
+  // Report the window duration here (the message previously printed the slide duration)
+  if (!_windowDuration.isMultipleOf(parent.slideDuration))
+    throw new Exception("The window duration of WindowedDStream (" + _windowDuration + ") " +
+    "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
+
+  if (!_slideDuration.isMultipleOf(parent.slideDuration))
+    throw new Exception("The slide duration of WindowedDStream (" + _slideDuration + ") " +
+    "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")")
+
+  parent.persist(StorageLevel.MEMORY_ONLY_SER)
+
+  def windowDuration: Duration = _windowDuration
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = _slideDuration
+
+  override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+  /** Unions the parent's RDDs in the window that ends at `validTime`. Always returns Some. */
+  override def compute(validTime: Time): Option[RDD[T]] = {
+    val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
+    Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
+  }
+}
+
+
+
diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala
new file mode 100644
index 0000000000..974651f9f6
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala
@@ -0,0 +1,84 @@
+package spark.streaming.util
+
+private[streaming]
+trait Clock { // Time source abstraction; SystemClock and ManualClock in this file implement it
+ def currentTime(): Long // Current time of this clock (SystemClock returns System.currentTimeMillis())
+ def waitTillTime(targetTime: Long): Long // Blocks until the clock reaches targetTime, then returns the clock's time
+}
+
+/** Clock backed by the system wall clock, in milliseconds (System.currentTimeMillis). */
+private[streaming]
+class SystemClock() extends Clock {
+
+  // Lower bound, in milliseconds, for the polling interval used while waiting
+  val minPollTime = 25L
+
+  def currentTime(): Long = System.currentTimeMillis()
+
+  /**
+   * Sleeps in short polls until the system time reaches `targetTime` and returns the last
+   * time that was read. Returns immediately if the target is already reached.
+   */
+  def waitTillTime(targetTime: Long): Long = {
+    var now = System.currentTimeMillis()
+    var remaining = targetTime - now
+
+    if (remaining > 0) {
+      // Poll every tenth of the initial wait, but never more often than minPollTime
+      val pollTime = math.max((remaining / 10.0).toLong, minPollTime)
+
+      now = System.currentTimeMillis()
+      remaining = targetTime - now
+      while (remaining > 0) {
+        // Never sleep past the target
+        Thread.sleep(math.min(remaining, pollTime))
+        now = System.currentTimeMillis()
+        remaining = targetTime - now
+      }
+    }
+    now
+  }
+}
+
+/**
+ * Clock whose time only changes when `setTime` or `addToTime` is called. Starts at 0.
+ * Writers notify waiters on this object's monitor after every change.
+ */
+private[streaming]
+class ManualClock() extends Clock {
+
+  var time = 0L
+
+  def currentTime() = time
+
+  /** Sets the clock to `timeToSet` and wakes up all waiting threads. */
+  def setTime(timeToSet: Long) = synchronized {
+    time = timeToSet
+    notifyAll()
+  }
+
+  /** Advances the clock by `timeToAdd` and wakes up all waiting threads. */
+  def addToTime(timeToAdd: Long) = synchronized {
+    time += timeToAdd
+    notifyAll()
+  }
+
+  /** Blocks until the clock has reached `targetTime`, then returns the current time. */
+  def waitTillTime(targetTime: Long): Long = {
+    synchronized {
+      // Re-check at least every 100 ms even without a notification
+      while (time < targetTime) wait(100)
+    }
+    currentTime()
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala
new file mode 100644
index 0000000000..03749d4a94
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextHelper.scala
@@ -0,0 +1,98 @@
+package spark.streaming.util
+
+import spark.SparkContext
+import spark.SparkContext._
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+import scala.collection.JavaConversions.mapAsScalaMap
+
+/** Helper functions for word-count style processing of raw text lines. */
+object RawTextHelper {
+
+  /**
+   * Splits lines and counts the words in them using specialized object-to-long hashmap
+   * (to avoid boxing-unboxing overhead of Long in java/scala HashMap).
+   * Words are maximal runs of characters other than ' '; only the space character separates.
+   */
+  def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = {
+    val map = new OLMap[String]
+    var i = 0
+    var j = 0
+    while (iter.hasNext) {
+      val s = iter.next()
+      i = 0
+      while (i < s.length) {
+        // Advance j to the end of the word starting at i
+        j = i
+        while (j < s.length && s.charAt(j) != ' ') {
+          j += 1
+        }
+        if (j > i) {
+          val w = s.substring(i, j)
+          // NOTE(review): relies on getLong returning 0 for a word not yet in the map - confirm
+          val c = map.getLong(w)
+          map.put(w, c + 1)
+        }
+        // Skip the run of spaces that follows the word
+        i = j
+        while (i < s.length && s.charAt(i) == ' ') {
+          i += 1
+        }
+      }
+    }
+    map.toIterator.map{case (k, v) => (k, v)}
+  }
+
+  /**
+   * Gets the top k words in terms of word counts. Assumes that each word exists only once
+   * in the `data` iterator (that is, the counts have been reduced).
+   *
+   * Null entries in `data` are skipped. The result is ordered by descending count and has
+   * min(k, number of non-null inputs) elements; it is empty when k <= 0.
+   */
+  def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = {
+    if (k <= 0) {
+      // Nothing can be kept; this also avoids indexing into a zero-length array
+      Iterator.empty
+    } else {
+      // Best entries seen so far, sorted by descending count in taken(0 until len)
+      val taken = new Array[(String, Long)](k)
+
+      var i = 0
+      var len = 0
+      var value: (String, Long) = null
+      var swap: (String, Long) = null
+
+      while(data.hasNext) {
+        value = data.next
+        if (value != null) {
+          // Accept while there is room, or when the value beats the current smallest entry
+          if (len < k || value._2 > taken(len - 1)._2) {
+            if (len < k) {
+              len += 1
+            }
+            // Put the value in the last slot, then bubble it up to its sorted position
+            taken(len - 1) = value
+            i = len - 1
+            while(i > 0 && taken(i - 1)._2 < taken(i)._2) {
+              swap = taken(i)
+              taken(i) = taken(i-1)
+              taken(i - 1) = swap
+              i -= 1
+            }
+          }
+        }
+      }
+      // Only the filled prefix is returned, so callers never see null padding
+      taken.iterator.take(len)
+    }
+  }
+
+  /**
+   * Warms up the SparkContext in master and slave by running tasks to force JIT kick in
+   * before real workload starts.
+   */
+  def warmUp(sc: SparkContext) {
+    for(i <- 0 to 1) {
+      sc.parallelize(1 to 200000, 1000)
+        .map(_ % 1331).map(_.toString)
+        .mapPartitions(splitAndCountPartitions).reduceByKey(_ + _, 10)
+        .count()
+    }
+  }
+
+  /** Sum of two counts. */
+  def add(v1: Long, v2: Long) = (v1 + v2)
+
+  /** Difference of two counts (v1 - v2). */
+  def subtract(v1: Long, v2: Long) = (v1 - v2)
+
+  /** Larger of two counts. */
+  def max(v1: Long, v2: Long) = math.max(v1, v2)
+}
+
diff --git a/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
new file mode 100644
index 0000000000..d8b987ec86
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/RawTextSender.scala
@@ -0,0 +1,60 @@
+package spark.streaming.util
+
+import java.nio.ByteBuffer
+import spark.util.{RateLimitedOutputStream, IntParam}
+import java.net.ServerSocket
+import spark.{Logging, KryoSerializer}
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import io.Source
+import java.io.IOException
+
+/**
+ * A helper program that sends blocks of Kryo-serialized text strings out on a socket at a
+ * specified rate. Used to feed data into RawInputDStream.
+ *
+ * Each block is written as a 4-byte length followed by the serialized bytes, and the same
+ * block is sent repeatedly to every client until that client disconnects.
+ */
+object RawTextSender extends Logging {
+  def main(args: Array[String]) {
+    if (args.length != 4) {
+      System.err.println("Usage: RawTextSender <port> <file> <blockSize> <bytesPerSec>")
+      System.exit(1)
+    }
+    // Parse the arguments using a pattern match
+    val Array(IntParam(port), file, IntParam(blockSize), IntParam(bytesPerSec)) = args
+
+    // Read all lines up front and release the file handle right away
+    val source = Source.fromFile(file)
+    val lines = try {
+      source.getLines().toArray
+    } finally {
+      source.close()
+    }
+    // Without any line the fill loop below could never make progress
+    if (lines.isEmpty) {
+      System.err.println("Input file " + file + " contains no lines")
+      System.exit(1)
+    }
+
+    // Repeat the input data multiple times to fill in a buffer
+    val bufferStream = new FastByteArrayOutputStream(blockSize + 1000)
+    val ser = new KryoSerializer().newInstance()
+    val serStream = ser.serializeStream(bufferStream)
+    var i = 0
+    while (bufferStream.position < blockSize) {
+      serStream.writeObject(lines(i))
+      i = (i + 1) % lines.length
+    }
+    bufferStream.trim()
+    val array = bufferStream.array
+
+    // 4-byte length header sent in front of every block
+    val countBuf = ByteBuffer.wrap(new Array[Byte](4))
+    countBuf.putInt(array.length)
+    countBuf.flip()
+
+    val serverSocket = new ServerSocket(port)
+    logInfo("Listening on port " + port)
+
+    // Serve one client at a time, forever
+    while (true) {
+      val socket = serverSocket.accept()
+      logInfo("Got a new connection")
+      val out = new RateLimitedOutputStream(socket.getOutputStream, bytesPerSec)
+      try {
+        while (true) {
+          out.write(countBuf.array)
+          out.write(array)
+        }
+      } catch {
+        case e: IOException =>
+          logError("Client disconnected")
+      } finally {
+        // Release the connection however the send loop ended
+        socket.close()
+      }
+    }
+  }
+}
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
new file mode 100644
index 0000000000..db715cc295
--- /dev/null
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -0,0 +1,75 @@
+package spark.streaming.util
+
+/**
+ * Timer that invokes `callback` with the scheduled time once every `period` units of the
+ * given clock, on its own thread, until `stop()` interrupts that thread.
+ */
+private[streaming]
+class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
+
+  val minPollTime = 25L
+
+  // A tenth of the period, but never less than minPollTime
+  val pollTime = math.max((period / 10.0).toLong, minPollTime)
+
+  val thread = new Thread() {
+    override def run() { loop() }
+  }
+
+  // Time at which the callback fires next
+  var nextTime = 0L
+
+  /** Starts the timer thread with the first callback scheduled at `startTime`. */
+  def start(startTime: Long): Long = {
+    nextTime = startTime
+    thread.start()
+    nextTime
+  }
+
+  /** Starts the timer at the next multiple of `period` after the clock's current time. */
+  def start(): Long = {
+    val periodsElapsed = math.floor(clock.currentTime().toDouble / period)
+    start((periodsElapsed + 1).toLong * period)
+  }
+
+  /**
+   * Starts the timer at the first time after the clock's current time that is a whole
+   * number of periods away from `originalStartTime`.
+   */
+  def restart(originalStartTime: Long): Long = {
+    val gap = clock.currentTime() - originalStartTime
+    val periodsElapsed = math.floor(gap.toDouble / period).toLong
+    start((periodsElapsed + 1) * period + originalStartTime)
+  }
+
+  /** Interrupts the timer thread. */
+  def stop() {
+    thread.interrupt()
+  }
+
+  /** Timer body: wait for the next scheduled time, fire the callback, advance by one period. */
+  def loop() {
+    try {
+      while (true) {
+        clock.waitTillTime(nextTime)
+        callback(nextTime)
+        nextTime += period
+      }
+    } catch {
+      // An interrupt (see stop()) ends the loop quietly
+      case e: InterruptedException =>
+    }
+  }
+}
+
+private[streaming]
+object RecurringTimer {
+
+  /**
+   * Manual check of the timer: runs a one-second timer on the system clock for 30 seconds
+   * and prints, on every tick, the current time and the gap since the previous tick.
+   */
+  def main(args: Array[String]) {
+    val period = 1000
+    var lastRecurTime = 0L
+
+    val onRecur = (time: Long) => {
+      val now = System.currentTimeMillis()
+      println("" + now + ": " + (now - lastRecurTime))
+      lastRecurTime = now
+    }
+
+    val timer = new RecurringTimer(new SystemClock(), period, onRecur)
+    timer.start()
+    Thread.sleep(30 * 1000)
+    timer.stop()
+  }
+}
+
diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
new file mode 100644
index 0000000000..c84e7331c7
--- /dev/null
+++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java
@@ -0,0 +1,1029 @@
+package spark.streaming;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Files;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+import spark.HashPartitioner;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.*;
+import spark.storage.StorageLevel;
+import spark.streaming.api.java.JavaDStream;
+import spark.streaming.api.java.JavaPairDStream;
+import spark.streaming.api.java.JavaStreamingContext;
+import spark.streaming.JavaTestUtils;
+import spark.streaming.JavaCheckpointTestUtils;
+import spark.streaming.dstream.KafkaPartitionKey;
+
+import java.io.*;
+import java.util.*;
+
+// The test suite itself is Serializable so that anonymous Function implementations can be
+// serialized, as an alternative to converting these anonymous classes to static inner classes;
+// see http://stackoverflow.com/questions/758570/.
+public class JavaAPISuite implements Serializable {
+ private transient JavaStreamingContext ssc;
+
+ @Before
+ public void setUp() {
+ ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+ ssc.checkpoint("checkpoint", new Duration(1000));
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port");
+ }
+
+ @Test
+ public void testCount() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3,4),
+ Arrays.asList(3,4,5),
+ Arrays.asList(3));
+
+ List<List<Long>> expected = Arrays.asList(
+ Arrays.asList(4L),
+ Arrays.asList(3L),
+ Arrays.asList(1L));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream count = stream.count();
+ JavaTestUtils.attachTestOutputStream(count);
+ List<List<Long>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testMap() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("hello", "world"),
+ Arrays.asList("goodnight", "moon"));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(5,5),
+ Arrays.asList(9,4));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream letterCount = stream.map(new Function<String, Integer>() {
+ @Override
+ public Integer call(String s) throws Exception {
+ return s.length();
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(letterCount);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testWindow() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6,1,2,3),
+ Arrays.asList(7,8,9,4,5,6),
+ Arrays.asList(7,8,9));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream windowed = stream.window(new Duration(2000));
+ JavaTestUtils.attachTestOutputStream(windowed);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 4, 4);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testWindowWithSlideDuration() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9),
+ Arrays.asList(10,11,12),
+ Arrays.asList(13,14,15),
+ Arrays.asList(16,17,18));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1,2,3,4,5,6),
+ Arrays.asList(1,2,3,4,5,6,7,8,9,10,11,12),
+ Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18),
+ Arrays.asList(13,14,15,16,17,18));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000));
+ JavaTestUtils.attachTestOutputStream(windowed);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 8, 4);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testTumble() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9),
+ Arrays.asList(10,11,12),
+ Arrays.asList(13,14,15),
+ Arrays.asList(16,17,18));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1,2,3,4,5,6),
+ Arrays.asList(7,8,9,10,11,12),
+ Arrays.asList(13,14,15,16,17,18));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream windowed = stream.tumble(new Duration(2000));
+ JavaTestUtils.attachTestOutputStream(windowed);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 6, 3);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testFilter() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("giants", "dodgers"),
+ Arrays.asList("yankees", "red socks"));
+
+ List<List<String>> expected = Arrays.asList(
+ Arrays.asList("giants"),
+ Arrays.asList("yankees"));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream filtered = stream.filter(new Function<String, Boolean>() {
+ @Override
+ public Boolean call(String s) throws Exception {
+ return s.contains("a");
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(filtered);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testGlom() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("giants", "dodgers"),
+ Arrays.asList("yankees", "red socks"));
+
+ List<List<List<String>>> expected = Arrays.asList(
+ Arrays.asList(Arrays.asList("giants", "dodgers")),
+ Arrays.asList(Arrays.asList("yankees", "red socks")));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream glommed = stream.glom();
+ JavaTestUtils.attachTestOutputStream(glommed);
+ List<List<List<String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testMapPartitions() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("giants", "dodgers"),
+ Arrays.asList("yankees", "red socks"));
+
+ List<List<String>> expected = Arrays.asList(
+ Arrays.asList("GIANTSDODGERS"),
+ Arrays.asList("YANKEESRED SOCKS"));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream mapped = stream.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ @Override
+ public Iterable<String> call(Iterator<String> in) {
+ String out = "";
+ while (in.hasNext()) {
+ out = out + in.next().toUpperCase();
+ }
+ return Lists.newArrayList(out);
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(mapped);
+ List<List<List<String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ private class IntegerSum extends Function2<Integer, Integer, Integer> {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ }
+
+ private class IntegerDifference extends Function2<Integer, Integer, Integer> {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 - i2;
+ }
+ }
+
+ @Test
+ public void testReduce() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(6),
+ Arrays.asList(15),
+ Arrays.asList(24));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream reduced = stream.reduce(new IntegerSum());
+ JavaTestUtils.attachTestOutputStream(reduced);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testReduceByWindow() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(6),
+ Arrays.asList(21),
+ Arrays.asList(39),
+ Arrays.asList(24));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ new IntegerDifference(), new Duration(2000), new Duration(1000));
+ JavaTestUtils.attachTestOutputStream(reducedWindowed);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 4, 4);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testQueueStream() {
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9));
+
+ JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc());
+ JavaRDD<Integer> rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3));
+ JavaRDD<Integer> rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6));
+ JavaRDD<Integer> rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9));
+
+ LinkedList<JavaRDD<Integer>> rdds = Lists.newLinkedList();
+ rdds.add(rdd1);
+ rdds.add(rdd2);
+ rdds.add(rdd3);
+
+ JavaDStream<Integer> stream = ssc.queueStream(rdds);
+ JavaTestUtils.attachTestOutputStream(stream);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testTransform() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1,2,3),
+ Arrays.asList(4,5,6),
+ Arrays.asList(7,8,9));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(3,4,5),
+ Arrays.asList(6,7,8),
+ Arrays.asList(9,10,11));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream transformed = stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
+ return in.map(new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer i) throws Exception {
+ return i + 2;
+ }
+ });
+ }});
+ JavaTestUtils.attachTestOutputStream(transformed);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testFlatMap() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("go", "giants"),
+ Arrays.asList("boo", "dodgers"),
+ Arrays.asList("athletics"));
+
+ List<List<String>> expected = Arrays.asList(
+ Arrays.asList("g","o","g","i","a","n","t","s"),
+ Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"),
+ Arrays.asList("a","t","h","l","e","t","i","c","s"));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream flatMapped = stream.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split("(?!^)"));
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(flatMapped);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @Test
+ public void testPairFlatMap() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("giants"),
+ Arrays.asList("dodgers"),
+ Arrays.asList("athletics"));
+
+ List<List<Tuple2<Integer, String>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<Integer, String>(6, "g"),
+ new Tuple2<Integer, String>(6, "i"),
+ new Tuple2<Integer, String>(6, "a"),
+ new Tuple2<Integer, String>(6, "n"),
+ new Tuple2<Integer, String>(6, "t"),
+ new Tuple2<Integer, String>(6, "s")),
+ Arrays.asList(
+ new Tuple2<Integer, String>(7, "d"),
+ new Tuple2<Integer, String>(7, "o"),
+ new Tuple2<Integer, String>(7, "d"),
+ new Tuple2<Integer, String>(7, "g"),
+ new Tuple2<Integer, String>(7, "e"),
+ new Tuple2<Integer, String>(7, "r"),
+ new Tuple2<Integer, String>(7, "s")),
+ Arrays.asList(
+ new Tuple2<Integer, String>(9, "a"),
+ new Tuple2<Integer, String>(9, "t"),
+ new Tuple2<Integer, String>(9, "h"),
+ new Tuple2<Integer, String>(9, "l"),
+ new Tuple2<Integer, String>(9, "e"),
+ new Tuple2<Integer, String>(9, "t"),
+ new Tuple2<Integer, String>(9, "i"),
+ new Tuple2<Integer, String>(9, "c"),
+ new Tuple2<Integer, String>(9, "s")));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction<String, Integer, String>() {
+ @Override
+ public Iterable<Tuple2<Integer, String>> call(String in) throws Exception {
+ List<Tuple2<Integer, String>> out = Lists.newArrayList();
+ for (String letter: in.split("(?!^)")) {
+ out.add(new Tuple2<Integer, String>(in.length(), letter));
+ }
+ return out;
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(flatMapped);
+ List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testUnion() {
+ List<List<Integer>> inputData1 = Arrays.asList(
+ Arrays.asList(1,1),
+ Arrays.asList(2,2),
+ Arrays.asList(3,3));
+
+ List<List<Integer>> inputData2 = Arrays.asList(
+ Arrays.asList(4,4),
+ Arrays.asList(5,5),
+ Arrays.asList(6,6));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(1,1,4,4),
+ Arrays.asList(2,2,5,5),
+ Arrays.asList(3,3,6,6));
+
+ JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2);
+ JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2);
+
+ JavaDStream unioned = stream1.union(stream2);
+ JavaTestUtils.attachTestOutputStream(unioned);
+ List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ /*
+ * Order-invariant comparison of two RDD streams' batches; sorts the inner lists of both arguments in
+ * place. This accounts for ordering variation within individual RDDs which occurs during windowing.
+ */
+ public static <T extends Comparable<? super T>> void assertOrderInvariantEquals(
+ List<List<T>> expected, List<List<T>> actual) {
+ for (List<T> list: expected) {
+ Collections.sort(list);
+ }
+ for (List<T> list: actual) {
+ Collections.sort(list);
+ }
+ Assert.assertEquals(expected, actual);
+ }
+
+
+ // PairDStream Functions
+ @Test
+ public void testPairFilter() {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("giants", "dodgers"),
+ Arrays.asList("yankees", "red socks"));
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, Integer>("giants", 6)),
+ Arrays.asList(new Tuple2<String, Integer>("yankees", 7)));
+
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = stream.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2 call(String in) throws Exception {
+ return new Tuple2<String, Integer>(in, in.length());
+ }
+ });
+
+ JavaPairDStream<String, Integer> filtered = pairStream.filter(
+ new Function<Tuple2<String, Integer>, Boolean>() {
+ @Override
+ public Boolean call(Tuple2<String, Integer> in) throws Exception {
+ return in._1().contains("a");
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(filtered);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ List<List<Tuple2<String, String>>> stringStringKVStream = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("california", "giants"),
+ new Tuple2<String, String>("new york", "yankees"),
+ new Tuple2<String, String>("new york", "mets")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks"),
+ new Tuple2<String, String>("california", "ducks"),
+ new Tuple2<String, String>("new york", "rangers"),
+ new Tuple2<String, String>("new york", "islanders")));
+
+ List<List<Tuple2<String, Integer>>> stringIntKVStream = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 1),
+ new Tuple2<String, Integer>("california", 3),
+ new Tuple2<String, Integer>("new york", 4),
+ new Tuple2<String, Integer>("new york", 1)),
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 5),
+ new Tuple2<String, Integer>("california", 5),
+ new Tuple2<String, Integer>("new york", 3),
+ new Tuple2<String, Integer>("new york", 1)));
+
+ @Test
+ public void testPairGroupByKey() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, List<String>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, List<String>>("california", Arrays.asList("dodgers", "giants")),
+ new Tuple2<String, List<String>>("new york", Arrays.asList("yankees", "mets"))),
+ Arrays.asList(
+ new Tuple2<String, List<String>>("california", Arrays.asList("sharks", "ducks")),
+ new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders"))));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, List<String>> grouped = pairStream.groupByKey();
+ JavaTestUtils.attachTestOutputStream(grouped);
+ List<List<Tuple2<String, List<String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPairReduceByKey() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 4),
+ new Tuple2<String, Integer>("new york", 5)),
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 10),
+ new Tuple2<String, Integer>("new york", 4)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> reduced = pairStream.reduceByKey(new IntegerSum());
+
+ JavaTestUtils.attachTestOutputStream(reduced);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testCombineByKey() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 4),
+ new Tuple2<String, Integer>("new york", 5)),
+ Arrays.asList(
+ new Tuple2<String, Integer>("california", 10),
+ new Tuple2<String, Integer>("new york", 4)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey(
+ new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer i) throws Exception {
+ return i;
+ }
+ }, new IntegerSum(), new IntegerSum(), new HashPartitioner(2));
+
+ JavaTestUtils.attachTestOutputStream(combined);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testCountByKey() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, Long>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Long>("california", 2L),
+ new Tuple2<String, Long>("new york", 2L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("california", 2L),
+ new Tuple2<String, Long>("new york", 2L)));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Long> counted = pairStream.countByKey();
+ JavaTestUtils.attachTestOutputStream(counted);
+ List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testGroupByKeyAndWindow() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, List<String>>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("dodgers", "giants")),
+ new Tuple2<String, List<String>>("new york", Arrays.asList("yankees", "mets"))),
+ Arrays.asList(new Tuple2<String, List<String>>("california",
+ Arrays.asList("sharks", "ducks", "dodgers", "giants")),
+ new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))),
+ Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("sharks", "ducks")),
+ new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders"))));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, List<String>> groupWindowed =
+ pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000));
+ JavaTestUtils.attachTestOutputStream(groupWindowed);
+ List<List<Tuple2<String, List<String>>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testReduceByKeyAndWindow() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, Integer>("california", 4),
+ new Tuple2<String, Integer>("new york", 5)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 14),
+ new Tuple2<String, Integer>("new york", 9)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 10),
+ new Tuple2<String, Integer>("new york", 4)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> reduceWindowed =
+ pairStream.reduceByKeyAndWindow(new IntegerSum(), new Duration(2000), new Duration(1000));
+ JavaTestUtils.attachTestOutputStream(reduceWindowed);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testUpdateStateByKey() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, Integer>("california", 4),
+ new Tuple2<String, Integer>("new york", 5)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 14),
+ new Tuple2<String, Integer>("new york", 9)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 14),
+ new Tuple2<String, Integer>("new york", 9)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
+ new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>(){
+ @Override
+ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+ int out = 0;
+ if (state.isPresent()) {
+ out = out + state.get();
+ }
+ for (Integer v: values) {
+ out = out + v;
+ }
+ return Optional.of(out);
+ }
+ });
+ JavaTestUtils.attachTestOutputStream(updated);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testReduceByKeyAndWindowWithInverse() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, Integer>("california", 4),
+ new Tuple2<String, Integer>("new york", 5)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 14),
+ new Tuple2<String, Integer>("new york", 9)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 10),
+ new Tuple2<String, Integer>("new york", 4)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> reduceWindowed =
+ pairStream.reduceByKeyAndWindow(new IntegerSum(), new IntegerDifference(), new Duration(2000), new Duration(1000));
+ JavaTestUtils.attachTestOutputStream(reduceWindowed);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testCountByKeyAndWindow() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, Long>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Long>("california", 2L),
+ new Tuple2<String, Long>("new york", 2L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("california", 4L),
+ new Tuple2<String, Long>("new york", 4L)),
+ Arrays.asList(
+ new Tuple2<String, Long>("california", 2L),
+ new Tuple2<String, Long>("new york", 2L)));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Long> counted =
+ pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000));
+ JavaTestUtils.attachTestOutputStream(counted);
+ List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testMapValues() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, String>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "DODGERS"),
+ new Tuple2<String, String>("california", "GIANTS"),
+ new Tuple2<String, String>("new york", "YANKEES"),
+ new Tuple2<String, String>("new york", "METS")),
+ Arrays.asList(new Tuple2<String, String>("california", "SHARKS"),
+ new Tuple2<String, String>("california", "DUCKS"),
+ new Tuple2<String, String>("new york", "RANGERS"),
+ new Tuple2<String, String>("new york", "ISLANDERS")));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, String> mapped = pairStream.mapValues(new Function<String, String>() {
+ @Override
+ public String call(String s) throws Exception {
+ return s.toUpperCase();
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(mapped);
+ List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testFlatMapValues() {
+ List<List<Tuple2<String, String>>> inputData = stringStringKVStream;
+
+ List<List<Tuple2<String, String>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers1"),
+ new Tuple2<String, String>("california", "dodgers2"),
+ new Tuple2<String, String>("california", "giants1"),
+ new Tuple2<String, String>("california", "giants2"),
+ new Tuple2<String, String>("new york", "yankees1"),
+ new Tuple2<String, String>("new york", "yankees2"),
+ new Tuple2<String, String>("new york", "mets1"),
+ new Tuple2<String, String>("new york", "mets2")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks1"),
+ new Tuple2<String, String>("california", "sharks2"),
+ new Tuple2<String, String>("california", "ducks1"),
+ new Tuple2<String, String>("california", "ducks2"),
+ new Tuple2<String, String>("new york", "rangers1"),
+ new Tuple2<String, String>("new york", "rangers2"),
+ new Tuple2<String, String>("new york", "islanders1"),
+ new Tuple2<String, String>("new york", "islanders2")));
+
+ JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(
+ ssc, inputData, 1);
+ JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+
+ JavaPairDStream<String, String> flatMapped = pairStream.flatMapValues(
+ new Function<String, Iterable<String>>() {
+ @Override
+ public Iterable<String> call(String in) {
+ List<String> out = new ArrayList<String>();
+ out.add(in + "1");
+ out.add(in + "2");
+ return out;
+ }
+ });
+
+ JavaTestUtils.attachTestOutputStream(flatMapped);
+ List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testCoGroup() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks"),
+ new Tuple2<String, String>("new york", "rangers")));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "giants"),
+ new Tuple2<String, String>("new york", "mets")),
+ Arrays.asList(new Tuple2<String, String>("california", "ducks"),
+ new Tuple2<String, String>("new york", "islanders")));
+
+ // Per batch and key: (values from stream 1, values from stream 2).
+ List<List<Tuple2<String, Tuple2<List<String>, List<String>>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Tuple2<List<String>, List<String>>>("california",
+ new Tuple2<List<String>, List<String>>(Arrays.asList("dodgers"), Arrays.asList("giants"))),
+ new Tuple2<String, Tuple2<List<String>, List<String>>>("new york",
+ new Tuple2<List<String>, List<String>>(Arrays.asList("yankees"), Arrays.asList("mets")))),
+ Arrays.asList(
+ new Tuple2<String, Tuple2<List<String>, List<String>>>("california",
+ new Tuple2<List<String>, List<String>>(Arrays.asList("sharks"), Arrays.asList("ducks"))),
+ new Tuple2<String, Tuple2<List<String>, List<String>>>("new york",
+ new Tuple2<List<String>, List<String>>(Arrays.asList("rangers"), Arrays.asList("islanders")))));
+
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<List<String>, List<String>>> grouped = pairStream1.cogroup(pairStream2);
+ JavaTestUtils.attachTestOutputStream(grouped);
+ List<List<Tuple2<String, Tuple2<List<String>, List<String>>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testJoin() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks"),
+ new Tuple2<String, String>("new york", "rangers")));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "giants"),
+ new Tuple2<String, String>("new york", "mets")),
+ Arrays.asList(new Tuple2<String, String>("california", "ducks"),
+ new Tuple2<String, String>("new york", "islanders")));
+
+ // Per batch and key: (value from stream 1, value from stream 2).
+ List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("dodgers", "giants")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("yankees", "mets"))),
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("sharks", "ducks")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("rangers", "islanders"))));
+
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<String, String>> joined = pairStream1.join(pairStream2);
+ JavaTestUtils.attachTestOutputStream(joined);
+ List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testCheckpointMasterRecovery() throws InterruptedException {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("this", "is"),
+ Arrays.asList("a", "test"),
+ Arrays.asList("counting", "letters"));
+
+ List<List<Integer>> expectedInitial = Arrays.asList(
+ Arrays.asList(4,2));
+ List<List<Integer>> expectedFinal = Arrays.asList(
+ Arrays.asList(1,4),
+ Arrays.asList(8,7));
+
+
+ File tempDir = Files.createTempDir();
+ ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000));
+
+ JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream letterCount = stream.map(new Function<String, Integer>() {
+ @Override
+ public Integer call(String s) throws Exception {
+ return s.length();
+ }
+ });
+ JavaCheckpointTestUtils.attachTestOutputStream(letterCount);
+ List<List<Integer>> initialResult = JavaTestUtils.runStreams(ssc, 1, 1);
+
+ assertOrderInvariantEquals(expectedInitial, initialResult);
+ Thread.sleep(1000);
+
+ ssc.stop();
+ ssc = new JavaStreamingContext(tempDir.getAbsolutePath());
+ ssc.start();
+ List<List<Integer>> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2);
+ assertOrderInvariantEquals(expectedFinal, finalResult);
+ }
+
+ /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD
+ @Test
+ public void testCheckpointofIndividualStream() throws InterruptedException {
+ List<List<String>> inputData = Arrays.asList(
+ Arrays.asList("this", "is"),
+ Arrays.asList("a", "test"),
+ Arrays.asList("counting", "letters"));
+
+ List<List<Integer>> expected = Arrays.asList(
+ Arrays.asList(4,2),
+ Arrays.asList(1,4),
+ Arrays.asList(8,7));
+
+ JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream letterCount = stream.map(new Function<String, Integer>() {
+ @Override
+ public Integer call(String s) throws Exception {
+ return s.length();
+ }
+ });
+ JavaCheckpointTestUtils.attachTestOutputStream(letterCount);
+
+ letterCount.checkpoint(new Duration(1000));
+
+ List<List<Integer>> result1 = JavaCheckpointTestUtils.runStreams(ssc, 3, 3);
+ assertOrderInvariantEquals(expected, result1);
+ }
+ */
+
+ // Input stream tests. These mostly just test that we can instantiate a given InputStream with
+ // Java arguments and assign it to a JavaDStream without producing type errors. Testing of the
+ // InputStream functionality is deferred to the existing Scala tests.
+ @Test
+ public void testKafkaStream() {
+ HashMap<String, Integer> topics = Maps.newHashMap();
+ HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
+ JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics);
+ JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets);
+ JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets,
+ StorageLevel.MEMORY_AND_DISK());
+ }
+
+ @Test
+ public void testNetworkTextStream() {
+ JavaDStream test = ssc.networkTextStream("localhost", 12345);
+ }
+
+ @Test
+ public void testNetworkString() {
+ class Converter extends Function<InputStream, Iterable<String>> {
+ public Iterable<String> call(InputStream in) {
+ BufferedReader reader = new BufferedReader(new InputStreamReader(in));
+ List<String> out = new ArrayList<String>();
+ try {
+ while (true) {
+ String line = reader.readLine();
+ if (line == null) { break; }
+ out.add(line);
+ }
+ } catch (IOException e) { }
+ return out;
+ }
+ }
+
+ JavaDStream test = ssc.networkStream(
+ "localhost",
+ 12345,
+ new Converter(),
+ StorageLevel.MEMORY_ONLY());
+ }
+
+ @Test
+ public void testTextFileStream() {
+ JavaDStream test = ssc.textFileStream("/tmp/foo");
+ }
+
+ @Test
+ public void testRawNetworkStream() {
+ JavaDStream test = ssc.rawNetworkStream("localhost", 12345);
+ }
+
+ @Test
+ public void testFlumeStream() {
+ JavaDStream test = ssc.flumeStream("localhost", 12345);
+ }
+
+ @Test
+ public void testFileStream() {
+ JavaPairDStream<String, String> foo =
+ ssc.<String, String, SequenceFileInputFormat>fileStream("/tmp/foo");
+ }
+}
diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
new file mode 100644
index 0000000000..56349837e5
--- /dev/null
+++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala
@@ -0,0 +1,65 @@
+package spark.streaming
+
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import java.util.{List => JList}
+import spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext}
+import spark.streaming._
+import java.util.ArrayList
+import collection.JavaConversions._
+
+/** Exposes streaming test functionality in a Java-friendly way. */
+trait JavaTestBase extends TestSuiteBase {
+
+ /**
+ * Create a [[spark.streaming.TestInputStream]] and attach it to the supplied context.
+ * The stream will be derived from the supplied lists of Java objects.
+ */
+ def attachTestInputStream[T](
+ ssc: JavaStreamingContext,
+ data: JList[JList[T]],
+ numPartitions: Int) = {
+ val seqData = data.map(Seq(_:_*))
+
+ implicit val cm: ClassManifest[T] = // no manifest for T is in scope here, so AnyRef's is reused
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions)
+ ssc.ssc.registerInputStream(dstream)
+ new JavaDStream[T](dstream)
+ }
+
+ /**
+ * Attach a provided stream to its associated StreamingContext as a
+ * [[spark.streaming.TestOutputStream]].
+ */
+ def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]](
+ dstream: JavaDStreamLike[T, This]) = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ val ostream = new TestOutputStream(dstream.dstream,
+ new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+ dstream.dstream.ssc.registerOutputStream(ostream)
+ }
+
+ /**
+ * Process all registered streams for numBatches batches, failing if
+ * numExpectedOutput RDDs are not generated. Generated RDDs are collected
+ * and returned, represented as a list for each batch interval.
+ */
+ def runStreams[V](
+ ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+ implicit val cm: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
+ val out = new ArrayList[JList[V]]()
+ res.foreach(entry => out.add(new ArrayList[V](entry)))
+ out
+ }
+}
+
+object JavaTestUtils extends JavaTestBase {
+
+}
+
+object JavaCheckpointTestUtils extends JavaTestBase {
+ override def actuallyWait = true
+} \ No newline at end of file
diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties
new file mode 100644
index 0000000000..edfa1243fa
--- /dev/null
+++ b/streaming/src/test/resources/log4j.properties
@@ -0,0 +1,11 @@
+# Set everything to be logged to the file streaming/target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=false
+log4j.appender.file.file=streaming/target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
new file mode 100644
index 0000000000..bfdf32c73e
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -0,0 +1,218 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import scala.runtime.RichInt
+import util.ManualClock
+
+class BasicOperationsSuite extends TestSuiteBase {
+
+ override def framework() = "BasicOperationsSuite"
+
+ after {
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ test("map") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.map(_.toString),
+ input.map(_.map(_.toString))
+ )
+ }
+
+ test("flatmap") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.flatMap(x => Seq(x, x * 2)),
+ input.map(_.flatMap(x => Array(x, x * 2)))
+ )
+ }
+
+ test("filter") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.filter(x => (x % 2 == 0)),
+ input.map(_.filter(x => (x % 2 == 0)))
+ )
+ }
+
+ test("glom") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(
+ Seq( Seq(1, 2), Seq(3, 4) ),
+ Seq( Seq(5, 6), Seq(7, 8) ),
+ Seq( Seq(9, 10), Seq(11, 12) )
+ )
+ val operation = (r: DStream[Int]) => r.glom().map(_.toSeq)
+ testOperation(input, operation, output)
+ }
+
+ test("mapPartitions") {
+ assert(numInputPartitions === 2, "Number of input partitions has been changed from 2")
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ val output = Seq(Seq(3, 7), Seq(11, 15), Seq(19, 23))
+ val operation = (r: DStream[Int]) => r.mapPartitions(x => Iterator(x.reduce(_ + _)))
+ testOperation(input, operation, output, true)
+ }
+
+ test("groupByKey") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(),
+ Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ),
+ true
+ )
+ }
+
+ test("reduceByKey") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ true
+ )
+ }
+
+ test("reduce") {
+ testOperation(
+ Seq(1 to 4, 5 to 8, 9 to 12),
+ (s: DStream[Int]) => s.reduce(_ + _),
+ Seq(Seq(10), Seq(26), Seq(42))
+ )
+ }
+
+ test("mapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).mapValues(_ + 10),
+ Seq( Seq(("a", 12), ("b", 11)), Seq(("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("flatMapValues") {
+ testOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _).flatMapValues(x => Seq(x, x + 10)),
+ Seq( Seq(("a", 2), ("a", 12), ("b", 1), ("b", 11)), Seq(("", 2), ("", 12)), Seq() ),
+ true
+ )
+ }
+
+ test("cogroup") {
+ val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
+ val outputData = Seq(
+ Seq( ("a", (Seq(1, 1), Seq("x", "x"))), ("b", (Seq(1), Seq("x"))) ),
+ Seq( ("a", (Seq(1), Seq())), ("b", (Seq(), Seq("x"))), ("", (Seq(1), Seq("x"))) ),
+ Seq( ("", (Seq(1), Seq())) ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("join") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, "x")), ("b", (1, "x")) ),
+ Seq( ("", (1, "x")) ),
+ Seq( ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x,1)).join(s2.map(x => (x,"x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("updateStateByKey") {
+ val inputData =
+ Seq(
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+ }
+ s.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
+ }
+
+ testOperation(inputData, updateStateOperation, outputData, true)
+ }
+
+ test("forgetting of RDDs - map and window operations") {
+ assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second")
+
+ val input = (0 until 10).map(x => Seq(x, x + 1)).toSeq
+ val rememberDuration = Seconds(3)
+
+ assert(input.size === 10, "Number of inputs have changed")
+
+ def operation(s: DStream[Int]): DStream[(Int, Int)] = {
+ s.map(x => (x % 10, 1))
+ .window(Seconds(2), Seconds(1))
+ .window(Seconds(4), Seconds(2))
+ }
+
+ val ssc = setupStreams(input, operation _)
+ ssc.remember(rememberDuration)
+ runStreams[(Int, Int)](ssc, input.size, input.size / 2)
+
+ val windowedStream2 = ssc.graph.getOutputStreams().head.dependencies.head
+ val windowedStream1 = windowedStream2.dependencies.head
+ val mappedStream = windowedStream1.dependencies.head
+
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ assert(clock.time === Seconds(10).milliseconds)
+
+ // IDEALLY
+ // WindowedStream2 should remember till 7 seconds: 10, 8,
+ // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5
+ // MappedStream should remember till 7 seconds: 10, 9, 8, 7, 6, 5, 4, 3,
+
+ // IN THIS TEST
+ // WindowedStream2 should remember till 7 seconds: 10, 8,
+ // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
+ // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2
+
+ // WindowedStream2
+ assert(windowedStream2.generatedRDDs.contains(Time(10000)))
+ assert(windowedStream2.generatedRDDs.contains(Time(8000)))
+ assert(!windowedStream2.generatedRDDs.contains(Time(6000)))
+
+ // WindowedStream1
+ assert(windowedStream1.generatedRDDs.contains(Time(10000)))
+ assert(windowedStream1.generatedRDDs.contains(Time(4000)))
+ assert(!windowedStream1.generatedRDDs.contains(Time(3000)))
+
+ // MappedStream
+ assert(mappedStream.generatedRDDs.contains(Time(10000)))
+ assert(mappedStream.generatedRDDs.contains(Time(2000)))
+ assert(!mappedStream.generatedRDDs.contains(Time(1000)))
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
new file mode 100644
index 0000000000..d2f32c189b
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -0,0 +1,210 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import java.io.File
+import runtime.RichInt
+import org.scalatest.BeforeAndAfter
+import org.apache.commons.io.FileUtils
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.{Clock, ManualClock}
+
+class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+
+ before {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ after {
+ if (ssc != null) ssc.stop()
+ FileUtils.deleteDirectory(new File(checkpointDir))
+
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ var ssc: StreamingContext = null
+
+ override def framework = "CheckpointSuite"
+
+ override def batchDuration = Milliseconds(500)
+
+ override def checkpointInterval = batchDuration
+
+ override def actuallyWait = true
+
+ // Exercises checkpointing of a state stream: checkpoint files must appear, old ones must be
+ // deleted, and the stream must be restorable from the checkpoint directory twice in a row.
+ test("basic stream+rdd recovery") {
+
+ assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 500 milliseconds")
+ assert(checkpointInterval === batchDuration, "checkpointInterval for this test must be same as batchDuration")
+
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ val stateStreamCheckpointInterval = Seconds(1)
+
+ // this ensures checkpointing occurs at least once
+ val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2
+ val secondNumBatches = firstNumBatches
+
+ // Setup the streams
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(stateStreamCheckpointInterval)
+ .map(t => (t._1, t._2.self))
+ }
+ // Assign the suite-level `ssc` instead of declaring a shadowing local, so that the
+ // `after` hook can stop the context when one of the assertions below fails
+ ssc = setupStreams(input, operation)
+ var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
+
+ // Run till a time such that at least one RDD in the stream should have been checkpointed,
+ // then check whether some RDD has been checkpointed or not
+ ssc.start()
+ runStreamsWithRealDelay(ssc, firstNumBatches)
+ logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]")
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure")
+ stateStream.checkpointData.rdds.foreach {
+ case (time, data) => {
+ val file = new File(data.toString)
+ assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist")
+ }
+ }
+
+ // Run till a further time such that previous checkpoint files in the stream would be deleted
+ // and check whether the earlier checkpoint files are deleted
+ val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString))
+ runStreamsWithRealDelay(ssc, secondNumBatches)
+ checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
+ ssc.stop()
+
+ // Restart stream computation using the checkpoint file and check whether
+ // checkpointed RDDs have been restored or not
+ ssc = new StreamingContext(checkpointDir)
+ stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
+ logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]")
+ assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure")
+
+
+ // Run one batch to generate a new checkpoint file and check whether some RDD
+ // is present in the checkpoint data or not
+ ssc.start()
+ runStreamsWithRealDelay(ssc, 1)
+ assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure")
+ stateStream.checkpointData.rdds.foreach {
+ case (time, data) => {
+ val file = new File(data.toString)
+ assert(file.exists(),
+ "Checkpoint file '" + file +"' for time " + time + " for state stream before second failure does not exist")
+ }
+ }
+ ssc.stop()
+
+ // Restart stream computation from the new checkpoint file to see whether that file has
+ // correct checkpoint data
+ ssc = new StreamingContext(checkpointDir)
+ stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
+ logInfo("Restored data of state stream = \n[" + stateStream.generatedRDDs.mkString("\n") + "]")
+ assert(!stateStream.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure")
+
+ // Adjust manual clock time as if it is being restarted after a delay
+ System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString)
+ ssc.start()
+ runStreamsWithRealDelay(ssc, 4)
+ ssc.stop()
+ System.clearProperty("spark.streaming.manualClock.jump")
+ // Everything was stopped cleanly; clear the field so `after` does not stop it again
+ ssc = null
+ }
+
+ test("map and reduceByKey") {
+ testCheckpointedOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ 3
+ )
+ }
+
+ test("reduceByKeyAndWindowInv") {
+ val n = 10
+ val w = 4
+ val input = (1 to n).map(_ => Seq("a")).toSeq
+ val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
+ val operation = (st: DStream[String]) => {
+ st.map(x => (x, 1))
+ .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration)
+ .checkpoint(batchDuration * 2)
+ }
+ testCheckpointedOperation(input, operation, output, 7)
+ }
+
+ test("updateStateByKey") {
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val output = (1 to 10).map(x => Seq(("a", x))).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(batchDuration * 2)
+ .map(t => (t._1, t._2.self))
+ }
+ testCheckpointedOperation(input, operation, output, 7)
+ }
+
+ /**
+ * Tests a streaming operation under checkpointing, by restarting the operation
+ * from the checkpoint file and verifying whether the final output is correct.
+ * The input is assumed to have come from a reliable queue which can replay
+ * data as required.
+ */
+ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ initialNumBatches: Int
+ ) {
+
+ // Current code assumes that:
+ // number of inputs = number of outputs = number of batches to be run
+ val totalNumBatches = input.size
+ val nextNumBatches = totalNumBatches - initialNumBatches
+ val initialNumExpectedOutputs = initialNumBatches
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
+
+ // Do the computation for initial number of batches, create checkpoint file and quit
+ ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
+ Thread.sleep(1000)
+
+ // Restart and complete the computation from checkpoint file
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation " +
+ "\n-------------------------------------------\n"
+ )
+ ssc = new StreamingContext(checkpointDir)
+ val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs)
+ verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
+ ssc = null
+ }
+
+ /**
+ * Steps the streaming scheduler's manual clock forward one batch at a time for the given
+ * number of batches, sleeping for one real batch duration after every step and once more
+ * after the final step.
+ */
+ def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) {
+ val manualClock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val stepMillis = batchDuration.milliseconds
+ logInfo("Manual clock before advancing = " + manualClock.time)
+ var remaining = numBatches.toInt
+ while (remaining > 0) {
+ manualClock.addToTime(stepMillis)
+ Thread.sleep(stepMillis)
+ remaining -= 1
+ }
+ logInfo("Manual clock after advancing = " + manualClock.time)
+ Thread.sleep(stepMillis)
+ }
+
+} \ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
new file mode 100644
index 0000000000..7493ac1207
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala
@@ -0,0 +1,191 @@
+package spark.streaming
+
+import org.scalatest.BeforeAndAfter
+import org.apache.commons.io.FileUtils
+import java.io.File
+import scala.runtime.RichInt
+import scala.util.Random
+import spark.streaming.StreamingContext._
+import collection.mutable.ArrayBuffer
+import spark.Logging
+
+/**
+ * This testsuite tests master failures at random times while the stream is running using
+ * the real clock.
+ */
+class FailureSuite extends TestSuiteBase with BeforeAndAfter {
+
+ before {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ after {
+ FailureSuite.reset()
+ FileUtils.deleteDirectory(new File(checkpointDir))
+
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ // Framework name used for this suite's contexts (was copy-pasted as "CheckpointSuite")
+ override def framework = "FailureSuite"
+
+ override def batchDuration = Milliseconds(500)
+
+ override def checkpointDir = "checkpoint"
+
+ override def checkpointInterval = batchDuration
+
+ test("multiple failures with updateStateByKey") {
+ val n = 30
+ // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ...
+ val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq
+ // Last output: [ (a, 465) ] for n=30
+ val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) )
+
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(Seconds(2))
+ .map(t => (t._1, t._2.self))
+ }
+
+ testOperationWithMultipleFailures(input, operation, lastOutput, n, n)
+ }
+
+ test("multiple failures with reduceByKeyAndWindow") {
+ val n = 30
+ val w = 100
+ assert(w > n, "Window should be much larger than the number of input sets in this test")
+ // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ...
+ val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq
+ // Last output: [ (a, 465) ]
+ val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) )
+
+ val operation = (st: DStream[String]) => {
+ st.map(x => (x, 1))
+ .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration)
+ .checkpoint(Seconds(2))
+ }
+
+ testOperationWithMultipleFailures(input, operation, lastOutput, n, n)
+ }
+
+
+ /**
+ * Tests stream operation with multiple master failures, and verifies whether the
+ * final set of output values is as expected or not. Checking the final value is
+ * proof that no intermediate data was lost due to master failures.
+ *
+ * @param input one sequence of input records per batch
+ * @param operation the stream transformation under test
+ * @param lastExpectedOutput expected contents (compared as a set) of the last non-empty output
+ * @param numBatches number of batches each run is allowed to process
+ * @param numExpectedOutput maximum number of outputs to wait for in each run
+ */
+ def testOperationWithMultipleFailures[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ lastExpectedOutput: Seq[V],
+ numBatches: Int,
+ numExpectedOutput: Int
+ ) {
+ var ssc = setupStreams[U, V](input, operation)
+ val mergedOutput = new ArrayBuffer[Seq[V]]()
+
+ var totalTimeRan = 0L
+ // Keep running, getting killed and restarting until twice the nominal run time has been spent
+ while(totalTimeRan <= numBatches * batchDuration.milliseconds * 2) {
+ // NOTE(review): a KillingThread from an earlier iteration may still be sleeping when the
+ // next run starts and can then flag a failure for it -- confirm this is acceptable
+ new KillingThread(ssc, numBatches * batchDuration.milliseconds.toInt / 4).start()
+ val (output, timeRan) = runStreamsWithRealClock[V](ssc, numBatches, numExpectedOutput)
+
+ mergedOutput ++= output
+ totalTimeRan += timeRan
+ logInfo("New output = " + output)
+ logInfo("Merged output = " + mergedOutput)
+ logInfo("Total time spent = " + totalTimeRan)
+ val sleepTime = Random.nextInt(numBatches * batchDuration.milliseconds.toInt / 8)
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation in " + sleepTime + " ms " +
+ "\n-------------------------------------------\n"
+ )
+ Thread.sleep(sleepTime)
+ FailureSuite.failed = false
+ ssc = new StreamingContext(checkpointDir)
+ }
+ ssc.stop()
+ ssc = null
+
+ // Verify whether the last output is the expected one. Fail with a clear message when no
+ // run produced any non-empty output, instead of indexing the buffer with -1.
+ val lastOutput = mergedOutput.reverseIterator.find(!_.isEmpty)
+ assert(lastOutput.isDefined, "No non-empty output was generated by any run")
+ lastOutput.foreach(out => assert(out.toSet === lastExpectedOutput.toSet))
+ logInfo("Finished computation after " + FailureSuite.failureCount + " failures")
+ }
+
+ /**
+ * Runs the streams set up in `ssc` on real clock until the expected max number of
+ * outputs has been generated, the wait time has elapsed, or a failure has been flagged.
+ */
+ def runStreamsWithRealClock[V: ClassManifest](
+ ssc: StreamingContext,
+ numBatches: Int,
+ maxExpectedOutput: Int
+ ): (Seq[Seq[V]], Long) = {
+
+ System.clearProperty("spark.streaming.clock")
+
+ assert(numBatches > 0, "Number of batches to run stream computation is zero")
+ assert(maxExpectedOutput > 0, "Max expected outputs after " + numBatches + " is zero")
+ logInfo("numBatches = " + numBatches + ", maxExpectedOutput = " + maxExpectedOutput)
+
+ // Get the output buffer
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val output = outputStream.output
+ val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong
+ val startTime = System.currentTimeMillis()
+
+ try {
+ // Start computation
+ ssc.start()
+
+ // Wait until expected number of output items have been generated
+ while (output.size < maxExpectedOutput && System.currentTimeMillis() - startTime < waitTime && !FailureSuite.failed) {
+ logInfo("output.size = " + output.size + ", maxExpectedOutput = " + maxExpectedOutput)
+ Thread.sleep(100)
+ }
+ } catch {
+ case e: Exception => logInfo("Exception while running streams: " + e)
+ } finally {
+ ssc.stop()
+ }
+ val timeTaken = System.currentTimeMillis() - startTime
+ logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms")
+ (output, timeTaken)
+ }
+
+
+}
+
+/**
+ * Failure bookkeeping shared between the test thread and [[KillingThread]].
+ */
+object FailureSuite {
+ // Both fields are written by KillingThread and read by the test thread (which polls
+ // `failed` in a loop), so they are volatile to make the writes visible across threads.
+ @volatile var failed = false
+ @volatile var failureCount = 0
+
+ /** Clears the failure flag and resets the failure counter to zero. */
+ def reset() {
+ failed = false
+ failureCount = 0
+ }
+}
+
+/**
+ * Thread that stops the given streaming context after a random delay and records the
+ * failure in [[FailureSuite]].
+ *
+ * @param ssc the context to stop; nothing is stopped when it is null
+ * @param maxKillWaitTime upper bound (exclusive, in ms) of the random part of the delay
+ */
+class KillingThread(ssc: StreamingContext, maxKillWaitTime: Int) extends Thread with Logging {
+ initLogging()
+
+ override def run() {
+ val minKillWaitTime = if (FailureSuite.failureCount == 0) 3000 else 1000 // to allow the first checkpoint
+ // Random.nextInt rejects non-positive bounds, so fall back to no random part in that case
+ val randomWaitTime = if (maxKillWaitTime > 0) Random.nextInt(maxKillWaitTime) else 0
+ val killWaitTime = minKillWaitTime + randomWaitTime
+ logInfo("Kill wait time = " + killWaitTime)
+ Thread.sleep(killWaitTime.toLong)
+ logInfo(
+ "\n---------------------------------------\n" +
+ "Killing streaming context after " + killWaitTime + " ms" +
+ "\n---------------------------------------\n"
+ )
+ if (ssc != null) ssc.stop()
+ FailureSuite.failed = true
+ FailureSuite.failureCount += 1
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
new file mode 100644
index 0000000000..d7ba7a5d17
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala
@@ -0,0 +1,355 @@
+package spark.streaming
+
+import dstream.SparkFlumeEvent
+import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
+import java.io.{File, BufferedWriter, OutputStreamWriter}
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.ManualClock
+import spark.storage.StorageLevel
+import spark.Logging
+import scala.util.Random
+import org.apache.commons.io.FileUtils
+import org.scalatest.BeforeAndAfter
+import org.apache.flume.source.avro.AvroSourceProtocol
+import org.apache.flume.source.avro.AvroFlumeEvent
+import org.apache.flume.source.avro.Status
+import org.apache.avro.ipc.{specific, NettyTransceiver}
+import org.apache.avro.ipc.specific.SpecificRequestor
+import java.nio.ByteBuffer
+import collection.JavaConversions._
+import java.nio.charset.Charset
+
+class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
+
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ val testPort = 9999
+ var testServer: TestServer = null
+ var testDir: File = null
+
+ override def checkpointDir = "checkpoint"
+
+ after {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ if (testServer != null) {
+ testServer.stop()
+ testServer = null
+ }
+ if (testDir != null && testDir.exists()) {
+ FileUtils.deleteDirectory(testDir)
+ testDir = null
+ }
+
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ test("network input stream") {
+ // Start the server
+ testServer = new TestServer(testPort)
+ testServer.start()
+
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]]
+ val outputStream = new TestOutputStream(networkStream, outputBuffer)
+ def output = outputBuffer.flatMap(x => x)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Feed data to the server to send to the network receiver
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5)
+ val expectedOutput = input.map(_.toString)
+ Thread.sleep(1000)
+ for (i <- 0 until input.size) {
+ testServer.send(input(i).toString + "\n")
+ Thread.sleep(500)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(1000)
+ logInfo("Stopping server")
+ testServer.stop()
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ // (whether the elements were received one in each interval is not verified)
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i) === expectedOutput(i))
+ }
+ }
+
+ test("network input stream with checkpoint") {
+ // Start the server
+ testServer = new TestServer(testPort)
+ testServer.start()
+
+ // Set up the streaming context and input streams
+ var ssc = new StreamingContext(master, framework, batchDuration)
+ ssc.checkpoint(checkpointDir, checkpointInterval)
+ val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK)
+ var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]])
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Feed data to the server to send to the network receiver
+ var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ for (i <- Seq(1, 2, 3)) {
+ testServer.send(i.toString + "\n")
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(500)
+ assert(outputStream.output.size > 0)
+ ssc.stop()
+
+ // Restart stream computation from checkpoint and feed more data to see whether
+ // they are being received and processed
+ logInfo("*********** RESTARTING ************")
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ for (i <- Seq(4, 5, 6)) {
+ testServer.send(i.toString + "\n")
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(500)
+ outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]]
+ assert(outputStream.output.size > 0)
+ ssc.stop()
+ }
+
+ test("flume input stream") {
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val flumeStream = ssc.flumeStream("localhost", 33333, StorageLevel.MEMORY_AND_DISK)
+ val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]]
+ with SynchronizedBuffer[Seq[SparkFlumeEvent]]
+ val outputStream = new TestOutputStream(flumeStream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5)
+
+ val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333));
+ val client = SpecificRequestor.getClient(
+ classOf[AvroSourceProtocol], transceiver);
+
+ for (i <- 0 until input.size) {
+ val event = new AvroFlumeEvent
+ event.setBody(ByteBuffer.wrap(input(i).toString.getBytes()))
+ event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header"))
+ client.append(event)
+ Thread.sleep(500)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+
+ val startTime = System.currentTimeMillis()
+ while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size)
+ Thread.sleep(100)
+ }
+ Thread.sleep(1000)
+ val timeTaken = System.currentTimeMillis() - startTime
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ logInfo("Stopping context")
+ ssc.stop()
+
+ val decoder = Charset.forName("UTF-8").newDecoder()
+
+ assert(outputBuffer.size === input.length)
+ for (i <- 0 until outputBuffer.size) {
+ assert(outputBuffer(i).size === 1)
+ val str = decoder.decode(outputBuffer(i).head.event.getBody)
+ assert(str.toString === input(i).toString)
+ assert(outputBuffer(i).head.event.getHeaders.get("test") === "header")
+ }
+ }
+
+ test("file input stream") {
+
+ // Create a temporary directory
+ testDir = {
+ var temp = File.createTempFile(".temp.", Random.nextInt().toString)
+ temp.delete()
+ temp.mkdirs()
+ logInfo("Created temp dir " + temp)
+ temp
+ }
+
+ // Set up the streaming context and input streams
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val filestream = ssc.textFileStream(testDir.toString)
+ val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+ def output = outputBuffer.flatMap(x => x)
+ val outputStream = new TestOutputStream(filestream, outputBuffer)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Create files in the temporary directory so that Spark Streaming can read data from it
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val input = Seq(1, 2, 3, 4, 5)
+ val expectedOutput = input.map(_.toString)
+ Thread.sleep(1000)
+ for (i <- 0 until input.size) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n")
+ Thread.sleep(500)
+ clock.addToTime(batchDuration.milliseconds)
+ //Thread.sleep(100)
+ }
+ val startTime = System.currentTimeMillis()
+ /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size)
+ Thread.sleep(100)
+ }*/
+ Thread.sleep(1000)
+ val timeTaken = System.currentTimeMillis() - startTime
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received by Spark Streaming was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Verify whether all the elements received are as expected
+ // (whether the elements were received one in each interval is not verified)
+ assert(output.size === expectedOutput.size)
+ for (i <- 0 until output.size) {
+ assert(output(i).size === 1)
+ assert(output(i).head.toString === expectedOutput(i))
+ }
+ }
+
+ test("file input stream with checkpoint") {
+ // Create a temporary directory
+ testDir = {
+ var temp = File.createTempFile(".temp.", Random.nextInt().toString)
+ temp.delete()
+ temp.mkdirs()
+ logInfo("Created temp dir " + temp)
+ temp
+ }
+
+ // Set up the streaming context and input streams
+ var ssc = new StreamingContext(master, framework, batchDuration)
+ ssc.checkpoint(checkpointDir, checkpointInterval)
+ val filestream = ssc.textFileStream(testDir.toString)
+ var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]])
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Create files and advance manual clock to process them
+ var clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ Thread.sleep(1000)
+ for (i <- Seq(1, 2, 3)) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(500)
+ logInfo("Output = " + outputStream.output.mkString(","))
+ assert(outputStream.output.size > 0)
+ ssc.stop()
+
+ // Restart stream computation from checkpoint and create more files to see whether
+ // they are being processed
+ logInfo("*********** RESTARTING ************")
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ Thread.sleep(500)
+ for (i <- Seq(4, 5, 6)) {
+ FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n")
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(500)
+ outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]]
+ logInfo("Output = " + outputStream.output.mkString(","))
+ assert(outputStream.output.size > 0)
+ ssc.stop()
+ }
+}
+
+
+/**
+ * Minimal TCP server for the test suites: it accepts one client at a time on `port` and
+ * writes every message queued through `send` to the connected client.
+ *
+ * @param port the local port the server socket is bound to on construction
+ */
+class TestServer(port: Int) extends Logging {
+
+ // Messages waiting to be written to the connected client (at most 100 are buffered)
+ val queue = new ArrayBlockingQueue[String](100)
+
+ val serverSocket = new ServerSocket(port)
+
+ val servingThread = new Thread() {
+ override def run() {
+ try {
+ while(true) {
+ logInfo("Accepting connections on port " + port)
+ val clientSocket = serverSocket.accept()
+ logInfo("New connection")
+ try {
+ clientSocket.setTcpNoDelay(true)
+ val outputStream = new BufferedWriter(new OutputStreamWriter(clientSocket.getOutputStream))
+
+ // Poll with a timeout so that an interrupt from stop() is noticed promptly
+ while(clientSocket.isConnected) {
+ val msg = queue.poll(100, TimeUnit.MILLISECONDS)
+ if (msg != null) {
+ outputStream.write(msg)
+ outputStream.flush()
+ logInfo("Message '" + msg + "' sent")
+ }
+ }
+ } catch {
+ case e: SocketException => logError("TestServer error", e)
+ } finally {
+ logInfo("Connection closed")
+ if (!clientSocket.isClosed) clientSocket.close()
+ }
+ }
+ } catch {
+ case ie: InterruptedException =>
+ // stop() interrupted the thread while it was waiting for a message
+ case se: SocketException =>
+ // accept() fails this way once stop() has closed the server socket; anything else is an error
+ if (serverSocket.isClosed) logInfo("Server socket closed") else logError("TestServer error", se)
+ } finally {
+ serverSocket.close()
+ }
+ }
+ }
+
+ /** Starts the serving thread. */
+ def start() { servingThread.start() }
+
+ /** Queues a message for the connected client. */
+ def send(msg: String) { queue.add(msg) }
+
+ /**
+ * Stops the server. Interrupting alone does not wake a thread blocked in accept(), so the
+ * server socket is closed as well; this releases the port even when no client ever connected.
+ */
+ def stop() {
+ servingThread.interrupt()
+ serverSocket.close()
+ }
+}
+
+object TestServer {
+ /** Standalone entry point: starts a server on port 9999 and queues "hello" every second. */
+ def main(args: Array[String]) {
+ val server = new TestServer(9999)
+ server.start()
+ while (true) {
+ Thread.sleep(1000)
+ server.send("hello")
+ }
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
new file mode 100644
index 0000000000..49129f3964
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -0,0 +1,291 @@
+package spark.streaming
+
+import spark.streaming.dstream.{InputDStream, ForEachDStream}
+import spark.streaming.util.ManualClock
+
+import spark.{RDD, Logging}
+
+import collection.mutable.ArrayBuffer
+import collection.mutable.SynchronizedBuffer
+
+import java.io.{ObjectInputStream, IOException}
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+/**
+ * This is an input stream just for the testsuites. This is equivalent to a checkpointable,
+ * replayable, reliable message queue like Kafka. It requires a sequence as input, and
+ * returns the i_th element at the i_th batch under manual clock.
+ */
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
+ extends InputDStream[T](ssc_) {
+
+ def start() {}
+
+ def stop() {}
+
+ def compute(validTime: Time): Option[RDD[T]] = {
+ logInfo("Computing RDD for time " + validTime)
+ val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+ val selectedInput = if (index < input.size) input(index) else Seq[T]()
+ val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
+ logInfo("Created RDD " + rdd.id + " with " + selectedInput)
+ Some(rdd)
+ }
+}
+
+/**
+ * This is an output stream just for the testsuites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ */
+class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.collect()
+ output += collected
+ }) {
+
+ // This is to clear the output buffer every time it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
+
+/**
+ * This is the base trait for Spark Streaming testsuites. This provides basic functionality
+ * to run user-defined set of input on user-defined stream operations, and verify the output.
+ */
+trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
+
+ def framework = "TestSuiteBase"
+
+ def master = "local[2]"
+
+ def batchDuration = Seconds(1)
+
+ def checkpointDir = "checkpoint"
+
+ def checkpointInterval = batchDuration
+
+ def numInputPartitions = 2
+
+ def maxWaitTimeMillis = 10000
+
+ def actuallyWait = false
+
+ /**
+ * Set up required DStreams to test the unary DStream operation using a single
+ * sequence of input collections.
+ */
+ def setupStreams[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval)
+ }
+
+ // Setup the stream computation
+ val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+ val operatedStream = operation(inputStream)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ ssc.registerInputStream(inputStream)
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
+
+ /**
+ * Set up required DStreams to test the binary DStream operation using the two
+ * sequences of input collections.
+ */
+ def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval)
+ }
+
+ // Setup the stream computation
+ val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
+ val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
+ val operatedStream = operation(inputStream1, inputStream2)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ ssc.registerInputStream(inputStream1)
+ ssc.registerInputStream(inputStream2)
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
+
+ /**
+ * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+ * returns the collected output. It will wait until `numExpectedOutput` number of
+ * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ */
+ def runStreams[V: ClassManifest](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[V]] = {
+
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ assert(numBatches > 0, "Number of batches to run stream computation is zero")
+ assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
+ logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
+
+ // Get the output buffer
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val output = outputStream.output
+
+ try {
+ // Start computation
+ ssc.start()
+
+ // Advance manual clock
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ logInfo("Manual clock before advancing = " + clock.time)
+ if (actuallyWait) {
+ for (i <- 1 to numBatches) {
+ logInfo("Actually waiting for " + batchDuration)
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ } else {
+ clock.addToTime(numBatches * batchDuration.milliseconds)
+ }
+ logInfo("Manual clock after advancing = " + clock.time)
+
+ // Wait until expected number of output items have been generated
+ val startTime = System.currentTimeMillis()
+ while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
+ Thread.sleep(100)
+ }
+ val timeTaken = System.currentTimeMillis() - startTime
+
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
+
+ Thread.sleep(500) // Give some time for the forgetting old RDDs to complete
+ } catch {
+ case e: Exception => e.printStackTrace(); throw e;
+ } finally {
+ ssc.stop()
+ }
+
+ output
+ }
+
+ /**
+ * Verify whether the output values after running a DStream operation
+ * are the same as the expected output values, by comparing the output
+ * collections either as lists (order matters) or sets (order does not matter)
+ */
+ def verifyOutput[V: ClassManifest](
+ output: Seq[Seq[V]],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean
+ ) {
+ logInfo("--------------------------------")
+ logInfo("output.size = " + output.size)
+ logInfo("output")
+ output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Match the output with the expected output
+ assert(output.size === expectedOutput.size, "Number of outputs do not match")
+ for (i <- 0 until output.size) {
+ if (useSet) {
+ assert(output(i).toSet === expectedOutput(i).toSet)
+ } else {
+ assert(output(i).toList === expectedOutput(i).toList)
+ }
+ }
+ logInfo("Output verified successfully")
+ }
+
+ /**
+ * Test unary DStream operation with a list of inputs, with number of
+ * batches to run same as the number of expected output values
+ */
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean = false
+ ) {
+ testOperation[U, V](input, operation, expectedOutput, -1, useSet)
+ }
+
+ /**
+ * Test unary DStream operation with a list of inputs
+ * @param input Sequence of input collections
+ * @param operation Unary DStream operation to be applied to the input
+ * @param expectedOutput Sequence of expected output collections
+ * @param numBatches Number of batches to run the operation for
+ * @param useSet If true, compare the output values with the expected output values
+ * as sets (order does not matter); otherwise as lists (order matters)
+ */
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[V](output, expectedOutput, useSet)
+ }
+
+ /**
+ * Test binary DStream operation with two lists of inputs, with number of
+ * batches to run same as the number of expected output values
+ */
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ useSet: Boolean
+ ) {
+ testOperation[U, V, W](input1, input2, operation, expectedOutput, -1, useSet)
+ }
+
+ /**
+ * Test binary DStream operation with two lists of inputs
+ * @param input1 First sequence of input collections
+ * @param input2 Second sequence of input collections
+ * @param operation Binary DStream operation to be applied to the 2 inputs
+ * @param expectedOutput Sequence of expected output collections
+ * @param numBatches Number of batches to run the operation for
+ * @param useSet If true, compare the output values with the expected output values
+ * as sets (order does not matter); otherwise as lists (order matters)
+ */
+ def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ input1: Seq[Seq[U]],
+ input2: Seq[Seq[V]],
+ operation: (DStream[U], DStream[V]) => DStream[W],
+ expectedOutput: Seq[Seq[W]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V, W](input1, input2, operation)
+ val output = runStreams[W](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[W](output, expectedOutput, useSet)
+ }
+}
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
new file mode 100644
index 0000000000..0c6e928835
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -0,0 +1,300 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+import collection.mutable.ArrayBuffer
+
+class WindowOperationsSuite extends TestSuiteBase {
+
+ override def framework = "WindowOperationsSuite"
+
+ override def maxWaitTimeMillis = 20000
+
+ override def batchDuration = Seconds(1)
+
+ after {
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.master.port")
+ }
+
+ val largerSlideInput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2)), // 1st window from here
+ Seq(("a", 3)),
+ Seq(("a", 4)), // 2nd window from here
+ Seq(("a", 5)),
+ Seq(("a", 6)), // 3rd window from here
+ Seq(),
+ Seq() // 4th window from here
+ )
+
+ val largerSlideReduceOutput = Seq(
+ Seq(("a", 3)),
+ Seq(("a", 10)),
+ Seq(("a", 18)),
+ Seq(("a", 11))
+ )
+
+
+ val bigInput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1), ("b", 1), ("c", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1)),
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1), ("b", 1), ("c", 1)),
+ Seq(("a", 1), ("b", 1)),
+ Seq(("a", 1)),
+ Seq()
+ )
+
+ val bigGroupByOutput = Seq(
+ Seq(("a", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1))),
+ Seq(("a", Seq(1))),
+ Seq(("a", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1, 1)), ("c", Seq(1))),
+ Seq(("a", Seq(1, 1)), ("b", Seq(1))),
+ Seq(("a", Seq(1)))
+ )
+
+
+ val bigReduceOutput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 1)),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 1))
+ )
+
+ /*
+ The output of the reduceByKeyAndWindow with inverse reduce function is
+ different from the naive reduceByKeyAndWindow. Even if the count of a
+ particular key is 0, the key does not get eliminated from the RDDs of
+ ReducedWindowedDStream. This causes the number of keys in these RDDs to
+ increase forever. A more generalized version that allows elimination of
+ keys should be considered.
+ */
+
+ val bigReduceInvOutput = Seq(
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 2), ("c", 1)),
+ Seq(("a", 2), ("b", 1), ("c", 0)),
+ Seq(("a", 1), ("b", 0), ("c", 0))
+ )
+
+ // Testing window operation
+
+ testWindow(
+ "basic window",
+ Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)),
+ Seq( Seq(0), Seq(0, 1), Seq(1, 2), Seq(2, 3), Seq(3, 4), Seq(4, 5))
+ )
+
+ testWindow(
+ "tumbling window",
+ Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)),
+ Seq( Seq(0, 1), Seq(2, 3), Seq(4, 5)),
+ Seconds(2),
+ Seconds(2)
+ )
+
+ testWindow(
+ "larger window",
+ Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)),
+ Seq( Seq(0, 1), Seq(0, 1, 2, 3), Seq(2, 3, 4, 5), Seq(4, 5)),
+ Seconds(4),
+ Seconds(2)
+ )
+
+ testWindow(
+ "non-overlapping window",
+ Seq( Seq(0), Seq(1), Seq(2), Seq(3), Seq(4), Seq(5)),
+ Seq( Seq(1, 2), Seq(4, 5)),
+ Seconds(2),
+ Seconds(3)
+ )
+
+ // Testing naive reduceByKeyAndWindow (without invertible function)
+
+ testReduceByKeyAndWindow(
+ "basic reduction",
+ Seq( Seq(("a", 1), ("a", 3)) ),
+ Seq( Seq(("a", 4)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "key already in window and new value added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "new key added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+ )
+
+ testReduceByKeyAndWindow(
+ "key removed from window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq() )
+ )
+
+ testReduceByKeyAndWindow(
+ "larger slide time",
+ largerSlideInput,
+ largerSlideReduceOutput,
+ Seconds(4),
+ Seconds(2)
+ )
+
+ testReduceByKeyAndWindow("big test", bigInput, bigReduceOutput)
+
+ // Testing reduceByKeyAndWindow (with invertible reduce function)
+
+ testReduceByKeyAndWindowInv(
+ "basic reduction",
+ Seq(Seq(("a", 1), ("a", 3)) ),
+ Seq(Seq(("a", 4)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "key already in window and new value added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "new key added into window",
+ Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),
+ Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "key removed from window",
+ Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ),
+ Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) )
+ )
+
+ testReduceByKeyAndWindowInv(
+ "larger slide time",
+ largerSlideInput,
+ largerSlideReduceOutput,
+ Seconds(4),
+ Seconds(2)
+ )
+
+ testReduceByKeyAndWindowInv("big test", bigInput, bigReduceInvOutput)
+
+ test("groupByKeyAndWindow") {
+ val input = bigInput
+ val expectedOutput = bigGroupByOutput.map(_.map(x => (x._1, x._2.toSet)))
+ val windowDuration = Seconds(2)
+ val slideDuration = Seconds(1)
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.groupByKeyAndWindow(windowDuration, slideDuration)
+ .map(x => (x._1, x._2.toSet))
+ .persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+
+ test("countByWindow") {
+ val input = Seq(Seq(1), Seq(1), Seq(1, 2), Seq(0), Seq(), Seq() )
+ val expectedOutput = Seq( Seq(1), Seq(2), Seq(3), Seq(3), Seq(1), Seq(0))
+ val windowDuration = Seconds(2)
+ val slideDuration = Seconds(1)
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[Int]) => {
+ s.countByWindow(windowDuration, slideDuration).map(_.toInt)
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+
+ test("countByKeyAndWindow") {
+ val input = Seq(Seq(("a", 1)), Seq(("b", 1), ("b", 2)), Seq(("a", 10), ("b", 20)))
+ val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3)))
+ val windowDuration = Seconds(2)
+ val slideDuration = Seconds(1)
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.countByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt))
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+
+
+ // Helper functions
+
+ def testWindow(
+ name: String,
+ input: Seq[Seq[Int]],
+ expectedOutput: Seq[Seq[Int]],
+ windowDuration: Duration = Seconds(2),
+ slideDuration: Duration = Seconds(1)
+ ) {
+ test("window - " + name) {
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[Int]) => s.window(windowDuration, slideDuration)
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
+
+ def testReduceByKeyAndWindow(
+ name: String,
+ input: Seq[Seq[(String, Int)]],
+ expectedOutput: Seq[Seq[(String, Int)]],
+ windowDuration: Duration = Seconds(2),
+ slideDuration: Duration = Seconds(1)
+ ) {
+ test("reduceByKeyAndWindow - " + name) {
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, windowDuration, slideDuration).persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
+
+ def testReduceByKeyAndWindowInv(
+ name: String,
+ input: Seq[Seq[(String, Int)]],
+ expectedOutput: Seq[Seq[(String, Int)]],
+ windowDuration: Duration = Seconds(2),
+ slideDuration: Duration = Seconds(1)
+ ) {
+ test("reduceByKeyAndWindowInv - " + name) {
+ val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration)
+ .persist()
+ .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
+ }
+ }
+}