author     Raymond Liu <raymond.liu@intel.com>    2013-11-12 15:14:21 +0800
committer  Raymond Liu <raymond.liu@intel.com>    2013-11-13 16:55:11 +0800
commit     0f2e3c6e31d56c627ff81cdc93289a7c7cb2ec16 (patch)
tree       60f01110b170ff72347e1ae6209f898712578ed3
parent     5429d62dfa16305eb23d67dfe38172803c80db65 (diff)
parent     3d4ad84b63e440fd3f4b3edb1b120ff7c14a42d1 (diff)
download   spark-0f2e3c6e31d56c627ff81cdc93289a7c7cb2ec16.tar.gz
           spark-0f2e3c6e31d56c627ff81cdc93289a7c7cb2ec16.tar.bz2
           spark-0f2e3c6e31d56c627ff81cdc93289a7c7cb2ec16.zip
Merge branch 'master' into scala-2.10
-rw-r--r--  README.md | 2
-rwxr-xr-x  bin/compute-classpath.sh | 22
-rwxr-xr-x  bin/slaves.sh | 19
-rwxr-xr-x  bin/spark-daemon.sh | 21
-rwxr-xr-x  bin/spark-daemons.sh | 2
-rwxr-xr-x  bin/stop-slaves.sh | 2
-rw-r--r--  core/pom.xml | 4
-rw-r--r--  core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java | 3
-rw-r--r--  core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java | 23
-rwxr-xr-x  core/src/main/java/org/apache/spark/network/netty/PathResolver.java | 11
-rw-r--r--  core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala | 49
-rw-r--r--  core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/FutureAction.scala | 250
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala | 30
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 169
-rw-r--r--  core/src/main/scala/org/apache/spark/ShuffleFetcher.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 229
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 19
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function.java | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function3.java | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala | 1058
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala | 410
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala | 54
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 247
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala | 603
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala | 420
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 49
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/Client.scala | 80
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala | 53
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala | 90
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 232
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala | 53
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala | 203
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala | 136
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala | 85
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 181
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala (renamed from core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala) | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 167
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala | 18
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/package.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala | 123
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 123
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala) | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 79
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 106
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 200
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala | 676
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala | 62
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Pool.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 48
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 76
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala | 106
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala) | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala) | 64
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala | 196
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockException.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala | 103
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockInfo.scala | 81
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 628
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessage.scala | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 142
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockStore.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 151
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskStore.scala | 280
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/FileSegment.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/MemoryStore.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 200
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala | 86
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageUtils.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 105
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala | 230
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/BitSet.scala | 103
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala | 154
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala | 272
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala | 128
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala | 53
-rw-r--r--  core/src/test/scala/org/apache/spark/BroadcastSuite.scala | 52
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala | 21
-rw-r--r--  core/src/test/scala/org/apache/spark/CheckpointSuite.scala | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/FileServerSuite.scala | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java | 23
-rw-r--r--  core/src/test/scala/org/apache/spark/JobCancellationSuite.scala | 209
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala | 176
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 35
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 136
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala | 49
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala | 28
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala | 114
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 102
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala | 84
-rw-r--r--  core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala | 154
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala | 73
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala | 148
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala | 145
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala | 90
-rw-r--r--  docker/README.md | 5
-rwxr-xr-x  docker/build | 22
-rw-r--r--  docker/spark-test/README.md | 11
-rw-r--r--  docker/spark-test/base/Dockerfile | 38
-rwxr-xr-x  docker/spark-test/build | 22
-rw-r--r--  docker/spark-test/master/Dockerfile | 21
-rwxr-xr-x  docker/spark-test/master/default_cmd | 22
-rw-r--r--  docker/spark-test/worker/Dockerfile | 22
-rwxr-xr-x  docker/spark-test/worker/default_cmd | 22
-rw-r--r--  docs/cluster-overview.md | 14
-rw-r--r--  docs/configuration.md | 10
-rw-r--r--  docs/ec2-scripts.md | 2
-rw-r--r--  docs/python-programming-guide.md | 11
-rw-r--r--  docs/running-on-yarn.md | 9
-rw-r--r--  docs/scala-programming-guide.md | 6
-rw-r--r--  docs/spark-standalone.md | 75
-rw-r--r--  docs/streaming-programming-guide.md | 9
-rw-r--r--  docs/tuning.md | 2
-rwxr-xr-x  ec2/spark_ec2.py | 2
-rw-r--r--  examples/pom.xml | 36
-rw-r--r--  examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java | 98
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala | 15
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala | 3
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala | 2
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/SparkPi.scala | 2
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala | 28
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala | 107
-rw-r--r--  examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala | 13
-rw-r--r--  pom.xml | 126
-rw-r--r--  project/SparkBuild.scala | 35
-rw-r--r--  python/pyspark/accumulators.py | 13
-rw-r--r--  python/pyspark/context.py | 50
-rw-r--r--  repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala | 36
-rwxr-xr-x  spark-class | 15
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar | bin 1358063 -> 0 bytes
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 | 1
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 | 1
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom | 9
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 | 1
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 | 1
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml | 12
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 | 1
-rw-r--r--  streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 | 1
-rw-r--r--  streaming/pom.xml | 57
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/DStream.scala | 55
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala | 12
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala | 155
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 52
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala | 8
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala | 97
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala | 186
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala | 108
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala | 63
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala | 110
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala | 14
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala | 4
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala | 20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala | 4
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java | 425
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala | 36
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 141
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 4
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala | 8
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 61
-rw-r--r--  tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala | 4
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 55
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 186
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala | 25
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala | 59
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala | 4
259 files changed, 10776 insertions, 5146 deletions
diff --git a/README.md b/README.md
index 7d3f9d4845..dd7e790534 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,7 @@ described below.
When developing a Spark application, specify the Hadoop version by adding the
"hadoop-client" artifact to your project's dependencies. For example, if you're
-using Hadoop 1.0.1 and build your application using SBT, add this entry to
+using Hadoop 1.2.1 and build your application using SBT, add this entry to
`libraryDependencies`:
"org.apache.hadoop" % "hadoop-client" % "1.2.1"
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 4fe3d0ef3a..40555089fc 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -32,12 +32,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
-if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+
+# First check if we have a dependencies jar. If so, include binary classes with the deps jar
+if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+
+ DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+ CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
- ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ # Else use spark-assembly jar from either RELEASE or assembly directory
+ if [ -f "$FWDIR/RELEASE" ]; then
+ ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+ else
+ ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ fi
+ CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
-CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
diff --git a/bin/slaves.sh b/bin/slaves.sh
index 752565b759..c367c2fd8e 100755
--- a/bin/slaves.sh
+++ b/bin/slaves.sh
@@ -28,7 +28,7 @@
# SPARK_SSH_OPTS Options passed to ssh when running remote commands.
##
-usage="Usage: slaves.sh [--config confdir] command..."
+usage="Usage: slaves.sh [--config <conf-dir>] command..."
# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi
diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh
index 5bfe967fbf..a0c0d44b58 100755
--- a/bin/spark-daemon.sh
+++ b/bin/spark-daemon.sh
@@ -29,7 +29,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
-usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
+usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
# get arguments
+
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
startStop=$1
shift
command=$1
diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh
index 354eb905a1..64286cb2da 100755
--- a/bin/spark-daemons.sh
+++ b/bin/spark-daemons.sh
@@ -19,7 +19,7 @@
# Run a Spark command on all slave hosts.
-usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
+usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then
diff --git a/bin/stop-slaves.sh b/bin/stop-slaves.sh
index 03e416a132..fcb8555d4e 100755
--- a/bin/stop-slaves.sh
+++ b/bin/stop-slaves.sh
@@ -17,8 +17,6 @@
# limitations under the License.
#
-# Starts the master on the machine this script is executed on.
-
bin=`dirname "$0"`
bin=`cd "$bin"; pwd`
diff --git a/core/pom.xml b/core/pom.xml
index 595240b5e5..468dd71249 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -49,6 +49,10 @@
<artifactId>avro-ipc</artifactId>
</dependency>
<dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
index c4aa2669e0..8a09210245 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
+import org.apache.spark.storage.BlockId;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
- public abstract void handleError(String blockId);
+ public abstract void handleError(BlockId blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index d3d57a0255..172c6e4b1c 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -24,6 +24,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,41 +36,36 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
}
@Override
- public void messageReceived(ChannelHandlerContext ctx, String blockId) {
- String path = pResolver.getAbsolutePath(blockId);
- // if getFilePath returns null, close the channel
- if (path == null) {
+ public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
+ BlockId blockId = BlockId.apply(blockIdString);
+ FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+ // if getBlockLocation returns null, close the channel
+ if (fileSegment == null) {
//ctx.close();
return;
}
- File file = new File(path);
+ File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
- //logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
- long length = file.length();
+ long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
- //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
- //logger.info("Sending block "+blockId+" filelen = "+len);
- //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
- .getChannel(), 0, file.length()));
+ .getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
- //logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
- //logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
index 94c034cad0..9f7ced44cf 100755
--- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
+++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
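
To illustrate the new contract, a hedged sketch of an implementing class follows. In this patch the role is actually played by the block-storage layer (see DiskBlockManager in the file list); the FileSegment constructor shape and the name method on BlockId are inferred from the accessors used in FileServerHandler above and are otherwise assumptions.

    // Illustrative only; assumes FileSegment(file, offset, length) and a name method on BlockId.
    import java.io.File
    import org.apache.spark.network.netty.PathResolver
    import org.apache.spark.storage.{BlockId, FileSegment}

    class WholeFileResolver(rootDir: File) extends PathResolver {
      // Each block is assumed to live in its own file, so the segment covers the whole file.
      override def getBlockLocation(blockId: BlockId): FileSegment = {
        val file = new File(rootDir, blockId.name)
        new FileSegment(file, 0, file.length())
      }
    }
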
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
+private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+ "org.apache.hadoop.mapred.JobContext")
+ val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+ classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+ "org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
+import org.apache.hadoop.conf.Configuration
+private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
+ val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
- // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+ // First, attempt to use the old-style constructor that takes a boolean isMap
+ // (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ classOf[Int], classOf[Int])
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
- // failed, look for the new ctor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+ // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+ val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ .asInstanceOf[Class[Enum[_]]]
+ val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+ taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
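
Both traits delegate to a firstAvailableClass helper that this hunk does not show; a plausible sketch of the pattern it implements (try the newer Hadoop class name first, fall back to the older one) would be:

    // Hypothetical sketch of the reflection fallback these utils rely on; not part of the patch.
    object HadoopClassFallback {
      def firstAvailableClass(first: String, second: String): Class[_] = {
        try {
          Class.forName(first)       // e.g. the Hadoop 2.x implementation class
        } catch {
          case e: ClassNotFoundException =>
            Class.forName(second)    // fall back to the Hadoop 1.x class name
        }
      }
    }
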
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 3ef402926e..1a2ec55876 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,43 +17,42 @@
package org.apache.spark
-import java.util.{HashMap => JHashMap}
+import org.apache.spark.util.AppendOnlyMap
-import scala.collection.JavaConversions._
-
-/** A set of functions used to aggregate data.
- *
- * @param createCombiner function to create the initial value of the aggregation.
- * @param mergeValue function to merge a new value into the aggregation result.
- * @param mergeCombiners function to merge outputs from multiple mergeValue function.
- */
+/**
+ * A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ */
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- for (kv <- iter) {
- val oldC = combiners.get(kv._1)
- if (oldC == null) {
- combiners.put(kv._1, createCombiner(kv._2))
- } else {
- combiners.put(kv._1, mergeValue(oldC, kv._2))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (iter.hasNext) {
+ kv = iter.next()
+ combiners.changeValue(kv._1, update)
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
- val combiners = new JHashMap[K, C]
- iter.foreach { case(k, c) =>
- val oldC = combiners.get(k)
- if (oldC == null) {
- combiners.put(k, c)
- } else {
- combiners.put(k, mergeCombiners(oldC, c))
- }
+ val combiners = new AppendOnlyMap[K, C]
+ var kc: (K, C) = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+ }
+ while (iter.hasNext) {
+ kc = iter.next()
+ combiners.changeValue(kc._1, update)
}
combiners.iterator
}
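
The rewritten Aggregator depends on AppendOnlyMap.changeValue, whose update closure receives a flag saying whether a previous value existed for the key. A minimal word-count sketch of that pattern, independent of the patch itself:

    // Minimal sketch of the changeValue pattern used above; AppendOnlyMap is added in this change set.
    import org.apache.spark.util.AppendOnlyMap

    val counts = new AppendOnlyMap[String, Int]
    val update = (hadValue: Boolean, oldValue: Int) => {
      if (hadValue) oldValue + 1 else 1   // "createCombiner" on first sight, "mergeValue" afterwards
    }
    for (word <- Iterator("a", "b", "a")) {
      counts.changeValue(word, update)
    }
    // counts.iterator now yields ("a", 2) and ("b", 1), in unspecified order
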
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56a6b..d9ed572da6 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ override def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
+ serializer: Serializer)
: Iterator[T] =
{
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+ (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
- val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, _) =>
+ case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
- CompletionIterator[T, Iterator[T]](itr, {
+ val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
- metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})
+
+ new InterruptibleIterator[T](context, completionIter)
}
}
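
The old regex parsing of "shuffle_%d_%d_%d" strings is replaced by matching on a typed block ID. A simplified stand-alone sketch of that idea follows; the real ShuffleBlockId is defined in the new BlockId.scala, not in this hunk, so its exact shape here is an assumption based on the usage above.

    // Simplified stand-in for the typed shuffle block IDs; field names are assumed from usage above.
    abstract class DemoBlockId { def name: String }
    case class DemoShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends DemoBlockId {
      def name = "shuffle_%d_%d_%d".format(shuffleId, mapId, reduceId)
    }

    val blockId: DemoBlockId = DemoShuffleBlockId(3, 7, 1)
    blockId match {
      case DemoShuffleBlockId(shufId, mapId, _) =>
        println("shuffle " + shufId + ", map " + mapId)   // no regex needed to recover the IDs
      case other =>
        println("not a shuffle block: " + other.name)
    }
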
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 4cf7eb96da..519ecde50a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
- private val loading = new HashSet[String]()
+ private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
+ val key = RDDBlockId(rdd.id, split.index)
logDebug("Looking for partition " + key)
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
- blockManager.put(key, elements, storageLevel, true)
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..1ad9240cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+ // Note that we redefine methods of the Future trait here explicitly so we can specify a different
+ // documentation (with reference to the word "action").
+
+ /**
+ * Cancels the execution of this action.
+ */
+ def cancel()
+
+ /**
+ * Blocks until this action completes.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @return this FutureAction
+ */
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+ /**
+ * Awaits and returns the result (of type T) of this action.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @throws Exception exception during action execution
+ * @return the result value if the action is completed within the specific maximum wait time
+ */
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+ /**
+ * When this action is completed, either through an exception, or a value, applies the provided
+ * function.
+ */
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+ /**
+ * Returns whether the action has already been completed with a value or an exception.
+ */
+ override def isCompleted: Boolean
+
+ /**
+ * The value of this Future.
+ *
+ * If the future is not completed the returned value will be None. If the future is completed
+ * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+ * it contains an exception.
+ */
+ override def value: Option[Try[T]]
+
+ /**
+ * Blocks and returns the result of this job.
+ */
+ @throws(classOf[Exception])
+ def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+ extends FutureAction[T] {
+
+ override def cancel() {
+ jobWaiter.cancel()
+ }
+
+ override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
+ if (!atMost.isFinite()) {
+ awaitResult()
+ } else {
+ val finishTime = System.currentTimeMillis() + atMost.toMillis
+ while (!isCompleted) {
+ val time = System.currentTimeMillis()
+ if (time >= finishTime) {
+ throw new TimeoutException
+ } else {
+ jobWaiter.wait(finishTime - time)
+ }
+ }
+ }
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ ready(atMost)(permit)
+ awaitResult() match {
+ case scala.util.Success(res) => res
+ case scala.util.Failure(e) => throw e
+ }
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ executor.execute(new Runnable {
+ override def run() {
+ func(awaitResult())
+ }
+ })
+ }
+
+ override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def value: Option[Try[T]] = {
+ if (jobWaiter.jobFinished) {
+ Some(awaitResult())
+ } else {
+ None
+ }
+ }
+
+ private def awaitResult(): Try[T] = {
+ jobWaiter.awaitResult() match {
+ case JobSucceeded => scala.util.Success(resultFunc)
+ case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ }
+ }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
+ */
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // A promise used to signal the future.
+ private val p = promise[T]()
+
+ override def cancel(): Unit = this.synchronized {
+ _cancelled = true
+ if (thread != null) {
+ thread.interrupt()
+ }
+ }
+
+ /**
+ * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+ * should use runJob implementation in this promise. See takeAsync for example.
+ */
+ def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+ scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ p.success(func)
+ } catch {
+ case e: Exception => p.failure(e)
+ } finally {
+ thread = null
+ }
+ }
+ this
+ }
+
+ /**
+ * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
+ def runJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R) {
+ // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+ // command need to be in an atomic block.
+ val job = this.synchronized {
+ if (!cancelled) {
+ rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+ } else {
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ // Wait for the job to complete. If the action is cancelled (with an interrupt),
+ // cancel the job and stop the execution. This is not in a synchronized block because
+ // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
+ try {
+ Await.ready(job, Duration.Inf)
+ } catch {
+ case e: InterruptedException =>
+ job.cancel()
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ /**
+ * Returns whether the promise has been cancelled.
+ */
+ def cancelled: Boolean = _cancelled
+
+ @throws(classOf[InterruptedException])
+ @throws(classOf[scala.concurrent.TimeoutException])
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+ p.future.ready(atMost)(permit)
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ p.future.result(atMost)(permit)
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+ p.future.onComplete(func)(executor)
+ }
+
+ override def isCompleted: Boolean = p.isCompleted
+
+ override def value: Option[Try[T]] = p.future.value
+}
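
A hedged usage sketch: actions that return a FutureAction, such as countAsync from the new AsyncRDDActions in the file list, can be blocked on with get() or cancelled with cancel(). Whether countAsync is exposed through the SparkContext._ implicits at this point is an assumption, and the local SparkContext below is only for illustration.

    // Illustrative only; assumes the async actions added elsewhere in this change set.
    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    val sc = new SparkContext("local", "future-action-demo")
    val future = sc.parallelize(1 to 1000000, 4).countAsync()   // assumed to return a FutureAction[Long]

    val n = future.get()   // block until the job finishes
    // future.cancel()     // or give up early, which cancels the underlying job(s)
    sc.stop()
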
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
new file mode 100644
index 0000000000..56e0b8d2c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An iterator that wraps around an existing iterator to provide task killing functionality.
+ * It works by checking the interrupted flag in TaskContext.
+ */
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+ extends Iterator[T] {
+
+ def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+ def next(): T = delegate.next()
+}
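
A behavioural sketch of the wrapper, using a stand-in for TaskContext (the real class carries more fields than the interrupted flag used here): once the flag flips, hasNext reports false and iteration stops even though the delegate still has elements.

    // Stand-in demo; mirrors the hasNext check in InterruptibleIterator above.
    class DemoContext { @volatile var interrupted = false }

    class DemoInterruptibleIterator[T](ctx: DemoContext, delegate: Iterator[T]) extends Iterator[T] {
      def hasNext: Boolean = !ctx.interrupted && delegate.hasNext
      def next(): T = delegate.next()
    }

    val ctx = new DemoContext
    val it = new DemoInterruptibleIterator(ctx, Iterator(1, 2, 3, 4))
    println(it.next())       // 1
    ctx.interrupted = true   // e.g. the task was killed
    println(it.hasNext)      // false: remaining elements are never visited
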
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1afb1870f1..035942ad39 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,7 +20,6 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
@@ -34,7 +33,7 @@ import scala.concurrent.duration._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -42,11 +41,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -62,22 +62,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@@ -88,50 +85,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -170,7 +129,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -196,9 +155,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -208,15 +166,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -230,14 +180,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -255,7 +253,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -263,13 +261,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
+ }
+
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -280,18 +296,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
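
The hunks above split the old MapOutputTracker into a worker-side class and a driver-only MapOutputTrackerMaster (which keeps the registration methods, the epoch counter and the cache of serialized statuses), and move the GZIP-based serialization helpers into the companion object. A minimal stand-alone sketch of the same ObjectOutputStream-over-GZIP round trip those helpers perform; this is illustrative code, not part of the patch:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object StatusCodecSketch extends App {
  // Serialize a (Serializable) payload the way serializeMapStatuses does: Java
  // serialization wrapped in GZIP, which compresses well when many entries repeat.
  def serialize(value: AnyRef): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    try objOut.writeObject(value) finally objOut.close()
    out.toByteArray
  }

  // Opposite direction, mirroring deserializeMapStatuses.
  def deserialize[T](bytes: Array[Byte]): T = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    try objIn.readObject().asInstanceOf[T] finally objIn.close()
  }

  val locations = Array("host-a:7077", "host-b:7077", "host-a:7077")
  val bytes = serialize(locations)
  val back = deserialize[Array[String]](bytes)
  assert(back.sameElements(locations))
  println("round-tripped " + back.length + " entries in " + bytes.length + " bytes")
}
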
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383a89..a85aa50a9b 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1b003cc685..cc44a4c7dd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.reflect.{ ClassTag, classTag}
@@ -53,21 +53,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.LocalSparkCluster
+import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+ TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -121,9 +119,9 @@ class SparkContext(
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
- private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+ private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
- // Initalize the Spark UI
+ // Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -149,6 +147,14 @@ class SparkContext(
executorEnvs ++= environment
}
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
+ }
+ executorEnvs("SPARK_USER") = sparkUser
+
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
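
The SPARK_USER block above wraps the inner lookup in a second Option because System.getenv("SPARK_USER") can itself return null. An equivalent (hypothetical) orElse chain makes the fallback order explicit:

object SparkUserSketch extends App {
  // Sketch only: prefer the JVM's user.name, then the SPARK_USER environment
  // variable, and finally the "<unknown>" placeholder.
  val sparkUser = Option(System.getProperty("user.name"))
    .orElse(Option(System.getenv("SPARK_USER")))
    .getOrElse("<unknown>")
  println("SPARK_USER resolved to " + sparkUser)
}
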
@@ -158,9 +164,11 @@ class SparkContext(
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """(spark://.*)""".r
- //Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """(mesos://.*)""".r
+ val SPARK_REGEX = """spark://(.*)""".r
+ // Regular expression for connection to Mesos cluster
+ val MESOS_REGEX = """mesos://(.*)""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
master match {
case "local" =>
@@ -174,7 +182,14 @@ class SparkContext(
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new ClusterScheduler(this)
+ val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
scheduler.initialize(backend)
scheduler
@@ -190,8 +205,8 @@ class SparkContext(
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val sparkUrl = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -210,25 +225,24 @@ class SparkContext(
throw new SparkException("YARN mode not available ?", th)
}
}
- val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
+ case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
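
The match above now captures the scheme-less part of the master URL directly, accepts a comma-separated list of standalone masters, adds simr://, and fails fast on anything unrecognized instead of silently treating it as a Mesos URL. A stand-alone sketch of how the new patterns dispatch; the object and method names here are made up for illustration:

object MasterUrlSketch extends App {
  val SPARK_REGEX = """spark://(.*)""".r
  val MESOS_REGEX = """mesos://(.*)""".r
  val SIMR_REGEX  = """simr://(.*)""".r

  def describe(master: String): String = master match {
    case SPARK_REGEX(urls) =>
      // Several masters may be listed for HA; the scheme is re-added per host.
      "standalone masters: " + urls.split(",").map("spark://" + _).mkString(", ")
    case MESOS_REGEX(url) => "mesos master: " + url
    case SIMR_REGEX(url)  => "simr driver file: " + url
    case other            => "could not parse '" + other + "'"
  }

  Seq("spark://h1:7077,h2:7077", "mesos://zk://zk1:2181/mesos", "simr://hdfs:///tmp/simr", "bogus")
    .foreach(m => println(describe(m)))
}
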
@@ -241,7 +255,7 @@ class SparkContext(
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
@@ -251,8 +265,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -285,15 +301,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ setJobGroup("", value)
+ }
+
+ /**
+ * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+ * running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description")
+ * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel")
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ /** Clear the job group id and its description. */
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
- val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -332,7 +379,7 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
@@ -344,7 +391,7 @@ class SparkContext(
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
- SparkEnv.get.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -358,24 +405,15 @@ class SparkContext(
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
- hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
- }
-
- /**
- * Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration
- * that has already been broadcast, assuming that it's safe to use it to construct a
- * HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued).
- */
- def hadoopFile[K, V](
- path: String,
- confBroadcast: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int
- ): RDD[(K, V)] = {
- new HadoopFileRDD(
- this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
+ val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
+ new HadoopRDD(
+ this,
+ confBroadcast,
+ Some(setInputPathsFunc),
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ minSplits)
}
/**
@@ -563,7 +601,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -657,12 +696,11 @@ class SparkContext(
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
- logWarning("null specified as parameter to addJar",
- new SparkException("null specified as parameter to addJar"))
+ logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@@ -671,12 +709,27 @@ class SparkContext(
} else {
val uri = new URI(path)
key = uri.getScheme match {
+ // A JAR file which exists only on the driver node
case null | "file" =>
- if (env.hadoop.isYarnMode()) {
- logWarning("local jar specified as parameter to addJar under Yarn mode")
- return
+ if (SparkHadoopUtil.get.isYarnMode()) {
+              // In order for this to work on YARN, the user must specify the --addJars option to
+              // the client to upload the file into the distributed cache so that it shows up in
+              // the current working directory.
+ val fileName = new Path(uri.getPath).getName()
+ try {
+ env.httpFileServer.addJar(new File(fileName))
+ } catch {
+ case e: Exception => {
+ logError("Error adding jar (" + e + "), was the --addJars option used?")
+ throw e
+ }
+ }
+ } else {
+ env.httpFileServer.addJar(new File(uri.getPath))
}
- env.httpFileServer.addJar(new File(uri.getPath))
+ // A JAR file which exists locally on every worker node
+ case "local" =>
+ "file:" + uri.getPath
case _ =>
path
}
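
Both addFile and addJar now understand a local: scheme for dependencies that already exist at the same path on every worker node, so nothing has to be served from the driver. A small sketch of the scheme dispatch; this is not Spark code, just an illustration:

import java.net.URI

object DependencySchemeSketch extends App {
  def resolve(path: String): String = new URI(path).getScheme match {
    case null | "file" => "serve from driver over HTTP: " + new URI(path).getPath
    case "local"       => "file:" + new URI(path).getPath   // already present on every worker
    case _             => "fetch as-is: " + path             // HDFS, HTTP, FTP, ...
  }

  Seq("/tmp/app.jar", "local:/opt/libs/app.jar", "hdfs:///jars/app.jar")
    .foreach(p => println(p + " -> " + resolve(p)))
}
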
@@ -750,13 +803,13 @@ class SparkContext(
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
- localProperties.get)
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -843,6 +896,42 @@ class SparkContext(
}
/**
+   * Submit a job for execution and return a FutureAction holding the result.
+ */
+ def submitJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): SimpleFutureAction[R] =
+ {
+ val cleanF = clean(processPartition)
+ val callSite = Utils.formatSparkCallSite
+ val waiter = dagScheduler.submitJob(
+ rdd,
+ (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
+ partitions,
+ callSite,
+ allowLocal = false,
+ resultHandler,
+ localProperties.get)
+ new SimpleFutureAction(waiter, resultFunc)
+ }
+
+ /**
+ * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+ * for more information.
+ */
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
+ /** Cancel all jobs that have been scheduled or are running. */
+ def cancelAllJobs() {
+ dagScheduler.cancelAllJobs()
+ }
+
+ /**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
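
submitJob (added a few hunks above) is the low-level entry point behind the new asynchronous actions: it runs processPartition on the chosen partitions, feeds each result to resultHandler, and completes the returned SimpleFutureAction with resultFunc. A hypothetical usage sketch, assuming the signature shown above, a local master, and that the returned SimpleFutureAction can be awaited like an ordinary scala.concurrent.Future:

import org.apache.spark.SparkContext
import scala.concurrent.Await
import scala.concurrent.duration.Duration

object SubmitJobSketch extends App {
  val sc = new SparkContext("local[2]", "submitJob-sketch")
  val rdd = sc.parallelize(1 to 1000, 4)
  val counts = new Array[Long](4)

  val future = sc.submitJob[Int, Long, Long](
    rdd,
    iter => iter.size.toLong,                 // work done per partition
    0 until 4,                                // which partitions to run
    (index, count) => counts(index) = count,  // collect each partition's result
    counts.sum)                               // value the future resolves to

  println("total = " + Await.result(future, Duration.Inf))
  sc.stop()
}
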
@@ -859,9 +948,8 @@ class SparkContext(
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
- val env = SparkEnv.get
val path = new Path(dir)
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -898,7 +986,12 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
- val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -925,6 +1018,8 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
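
The new rddToAsyncRDDActions implicit is what exposes the non-blocking actions built on submitJob. A hedged usage sketch, assuming countAsync is among the methods AsyncRDDActions provides and that its FutureAction can be awaited or cancelled like a standard Future:

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._   // brings rddToAsyncRDDActions into scope
import scala.concurrent.Await
import scala.concurrent.duration.Duration

object AsyncCountSketch extends App {
  val sc = new SparkContext("local[2]", "async-count-sketch")
  val rdd = sc.parallelize(1 to 100000, 4)

  val pending = rdd.countAsync()   // returns immediately with a FutureAction
  // Other driver-side work could happen here; pending.cancel() would abort the job.
  println("count = " + Await.result(pending, Duration.Inf))
  sc.stop()
}
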
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a267407c67..84750e2e85 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
+import com.google.common.collect.MapMaker
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -58,18 +58,9 @@ class SparkEnv (
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
- val hadoop = {
- val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
- if(yarnMode) {
- try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
- } catch {
- case th: Throwable => throw new SparkException("Unable to load YARN support", th)
- }
- } else {
- new SparkHadoopUtil
- }
- }
+ // A general, soft-reference map for metadata needed during HadoopRDD split computation
+ // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
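
The new hadoopJobMetadata map caches metadata needed during HadoopRDD split computation with soft-referenced values, so cached JobConfs and InputFormats can be reclaimed by the JVM under memory pressure. A small sketch using the same Guava construction; the get-or-compute helper is assumed for illustration, not Spark code:

import java.util.concurrent.ConcurrentMap
import com.google.common.collect.MapMaker

object SoftCacheSketch extends App {
  // Same construction as hadoopJobMetadata: a concurrent map whose values the JVM
  // may discard when memory gets tight.
  val cache: ConcurrentMap[String, AnyRef] =
    new MapMaker().softValues().makeMap[String, AnyRef]()

  def getOrCompute(key: String)(compute: => AnyRef): AnyRef = {
    val cached = cache.get(key)
    if (cached != null) cached
    else {
      val fresh = compute
      cache.putIfAbsent(key, fresh)
      fresh
    }
  }

  println(getOrCompute("jobconf:/data/input") { "expensive JobConf stand-in" })     // computed once
  println(getOrCompute("jobconf:/data/input") { sys.error("never evaluated") })     // served from cache
}
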
@@ -188,10 +179,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+private[apache]
+class SparkHadoopWriter(@transient jobConf: JobConf)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
getOutputCommitter().setupTask(getTaskContext())
- writer = getOutputFormat().getRecordWriter(
- fs, conf.value, outputName, Reporter.NULL)
+ writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
- if (writer!=null) {
- //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
}
+private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index c2c358c7ad..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,21 +17,30 @@
package org.apache.spark
-import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.executor.TaskMetrics
+
class TaskContext(
val stageId: Int,
- val splitId: Int,
+ val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
- val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ @volatile var interrupted: Boolean = false,
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
- @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @deprecated("use partitionId", "0.8.1")
+ def splitId = partitionId
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
- // Add a callback function to be executed on task completion. An example use
- // is for HadoopRDD to register a callback to close the input stream.
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * @param f Callback function.
+ */
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
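
TaskContext now exposes partitionId (with splitId kept as a deprecated alias), an interrupted flag used for job cancellation, and the documented completion-callback hook. A stand-alone stand-in (not Spark's class) showing how such callbacks fire exactly once when the task body finishes, successfully or not:

import scala.collection.mutable.ArrayBuffer

object TaskContextSketch extends App {
  class DemoTaskContext(val stageId: Int, val partitionId: Int) {
    private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
    def addOnCompleteCallback(f: () => Unit): Unit = onCompleteCallbacks += f
    def runTask[T](body: => T): T =
      try body finally onCompleteCallbacks.foreach(_())   // run callbacks even on failure
  }

  val ctx = new DemoTaskContext(stageId = 0, partitionId = 3)
  ctx.addOnCompleteCallback(() => println("closing input stream for partition " + ctx.partitionId))
  println("sum = " + ctx.runTask((1 to 10).sum))
}
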
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8466c2a004..c1e5e04b31 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
*/
private[spark] case object TaskResultLost extends TaskEndReason
+private[spark] case object TaskKilled extends TaskEndReason
+
private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index f0a1960a1b..e5e20dbb66 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -51,6 +51,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
+
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@@ -84,6 +97,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
+ /**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
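
repartition, added to all three Java wrappers in this patch, always shuffles and can either grow or shrink the partition count, whereas coalesce without a shuffle can only shrink it. A quick sketch against a local master; the setup is assumed for illustration and is not part of the patch:

import org.apache.spark.SparkContext

object RepartitionSketch extends App {
  val sc = new SparkContext("local[4]", "repartition-sketch")
  val rdd = sc.parallelize(1 to 100, 2)

  println(rdd.repartition(8).partitions.size)               // 8: the shuffle lets us grow
  println(rdd.repartition(8).coalesce(4).partitions.size)   // 4: shrink without another shuffle
  sc.stop()
}
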
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 899e17d4fa..eeea0eddb1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -66,6 +66,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -96,6 +109,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -599,4 +623,15 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+
+
+ /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
+ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ new JavaPairRDD[K, V](rdd.rdd)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 9968bc8e5f..c47657f512 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -43,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -76,6 +84,17 @@ JavaRDDLike[T, JavaRDD[T]] {
rdd.coalesce(numPartitions, shuffle)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 4830067f7a..3e85052cd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
-
- public abstract Iterable<Double> call(T t);
-
- @Override
- public final Iterable<Double> apply(T t) { return call(t); }
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
index ed92d31af5..5e9b8c48b8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -27,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
-
- public abstract Double call(T t) throws Exception;
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index b7c0d78e33..ed8fea97fc 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- @throws(classOf[Exception])
- def call(x: T) : java.lang.Iterable[R]
-
def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 7a505df4be..aae1349c5e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- @throws(classOf[Exception])
- def call(a: A, b:B) : java.lang.Iterable[C]
-
def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index e97116986f..49e661a376 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -32,7 +32,7 @@ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements S
public abstract R call(T t) throws Exception;
public ClassTag<R> returnType() {
- return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
+ return ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
new file mode 100644
index 0000000000..fb1deceab5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+import scala.runtime.AbstractFunction2;
+
+import java.io.Serializable;
+
+/**
+ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
+ */
+public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
+ implements Serializable {
+
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index fbd0cdabe0..ca485b3cc2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -33,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index f09559627d..cbe2306026 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -28,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
- extends WrappedFunction1<T, Tuple2<K, V>>
+public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public abstract Tuple2<K, V> call(T t) throws Exception;
-
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
new file mode 100644
index 0000000000..d314dbdf1d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction3
+
+/**
+ * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
+ extends AbstractFunction3[T1, T2, T3, R] {
+ @throws(classOf[Exception])
+ def call(t1: T1, t2: T2, t3: T3): R
+
+ final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
+}
+
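
WrappedFunction3 follows the same pattern as the existing WrappedFunction1 and WrappedFunction2 helpers: Java-friendly subclasses implement call(), which may throw, and apply() simply forwards to it so the object still works anywhere Scala expects a Function3. A self-contained illustration of that bridge; the names here are invented for the sketch:

import scala.runtime.AbstractFunction3

object CallBridgeSketch extends App {
  abstract class CallBridge3[T1, T2, T3, R] extends AbstractFunction3[T1, T2, T3, R] {
    @throws(classOf[Exception])
    def call(t1: T1, t2: T2, t3: T3): R
    final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
  }

  val add3 = new CallBridge3[Int, Int, Int, Int] {
    def call(a: Int, b: Int, c: Int): Int = a + b + c
  }
  // Because it really is a Function3, Scala collection methods accept it directly.
  println(Seq((1, 2, 3), (4, 5, 6)).map(add3.tupled))   // List(6, 15)
}
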
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 4d887cf195..53b53df9ac 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index 93e7815ab5..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1058 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id)
- with Logging
- with Serializable {
-
- def value = value_
-
- def blockId: String = "broadcast_" + id
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = new AtomicInteger(0)
-
- // Used ONLY by driver to track how many unique blocks have been sent out
- @transient var sentBlocks = new AtomicInteger(0)
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in driver
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks.set(variableInfo.totalBlocks)
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet(totalBlocks)
- hasBlocksBitVector.set(0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int](totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val driverSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- driverSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += driverSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Driver sends everything as 0/null
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = new AtomicInteger(0)
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo]()
-
- stopBroadcast = false
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks.get
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources(newSourceInfo: SourceInfo) {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources(newSourceInfo)
- }
- }
-
- class TalkToGuide(gInfo: SourceInfo)
- extends Thread with Logging {
- override def run() {
-
- // Keep exchaning information until all blocks have been received
- while (hasBlocks.get < totalBlocks) {
- talkOnce
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
- oosGuide.flush()
- oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush()
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Driver " + suitableSources)
-
- addToListOfSources(suitableSources)
-
- oisGuide.close()
- oosGuide.close()
- clientSocketToGuide.close()
- }
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- hasBlocksBitVector = new BitSet(totalBlocks)
- numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide(gInfo)
- ttGuide.setDaemon(true)
- ttGuide.start()
- logInfo("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon(true)
- pcController.start()
- logInfo("PeerChatterController started...")
-
- // FIXME: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks.get < totalBlocks) {
- Thread.sleep(MultiTracker.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo]()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet(totalBlocks)
-
- override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate = 0
- listOfSources.synchronized {
- numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
- threadPool.getActiveCount
- }
-
- while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkToRandom
-
- if (peerToTalkTo != null)
- logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
- else
- logDebug("No peer chosen...")
-
- if (peerToTalkTo != null) {
- threadPool.execute(new TalkToPeer(peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep(MultiTracker.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- // Right now picking the one that has the most blocks this peer wants
- // Also picking peer randomly if no one has anything interesting
- private def pickPeerToTalkToRandom: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logDebug("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Select the peer that has the most blocks that this receiver does not
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
- // Always picking randomly
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- return curPeer
- }
-
- // Picking peer with the weight of rare blocks it has
- private def pickPeerToTalkToRarestFirst: SourceInfo = {
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Count the number of copies of each block in the neighborhood
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // A block is considered rare if there are at most 2 copies of that block
- // This CONSTANT could be a function of the neighborhood size
- var rareBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
- rareBlocksIndices += i
- }
- }
-
- // Find peers with rare blocks
- var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
- var totalRareBlocks = 0
-
- peersNotInUse.foreach { eachPeer =>
- var hasRareBlocks = 0
- rareBlocksIndices.foreach { rareBlock =>
- if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
- hasRareBlocks += 1
- }
- }
-
- if (hasRareBlocks > 0) {
- peersWithRareBlocks += ((eachPeer, hasRareBlocks))
- }
- totalRareBlocks += hasRareBlocks
- }
-
- // Select a peer from peersWithRareBlocks based on weight calculated from
- // unique rare blocks
- var selectedPeerToTalkTo: SourceInfo = null
-
- if (peersWithRareBlocks.size > 0) {
- // Sort the peers based on how many rare blocks they have
- peersWithRareBlocks.sortBy(_._2)
-
- var randomNumber = MultiTracker.ranGen.nextDouble
- var tempSum = 0.0
-
- var i = 0
- do {
- tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
- if (tempSum >= randomNumber) {
- selectedPeerToTalkTo = peersWithRareBlocks(i)._1
- }
- i += 1
- } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
- }
-
- if (selectedPeerToTalkTo == null) {
- selectedPeerToTalkTo = pickPeerToTalkToRandom
- }
-
- return selectedPeerToTalkTo
- }
-
- class TalkToPeer(peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run() {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run() {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
-
- logInfo("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources(newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
-
- var keepReceiving = true
-
- while (hasBlocks.get < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other threads know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush()
-
- // CHANGED: Driver might send some other block than the one
- // requested to ensure fast spreading of all blocks.
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set(bcBlock.blockID)
- hasBlocks.getAndIncrement
- }
-
- // Some block(may NOT be blockToAskFor) has arrived.
- // In any case, blockToAskFor is not in request any more
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logError("TalktoPeer had a " + e)
- // FIXME: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
- }
-
- cleanUpConnections()
- }
- }
-
- // Right now it picks a block uniformly that this peer does not have
- private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
- i -= 1
- }
-
- return pickedBlockIndex
- }
- }
-
- // Pick the block that seems to be the rarest across sources
- private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Count the number of copies for each block across all sources
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // Find the minimum
- var minVal = Integer.MAX_VALUE
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
- minVal = numCopiesPerBlock(i)
- }
- }
-
- // Find the blocks with the least copies that this peer does not have
- var minBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
- minBlocksIndices += i
- }
- }
-
- // Now select a random index from minBlocksIndices
- if (minBlocksIndices.size == 0) {
- return -1
- } else {
- // Pick uniformly the i'th index
- var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
- return minBlocksIndices(i)
- }
- }
- }
-
- private def cleanUpConnections() {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
- }
- }
- }
-
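A note on the rarest-first peer choice that pickPeerToTalkToRarestFirst implemented above: it counts how many known sources hold each block, treats a block with at most two copies as rare, and then samples a peer with probability proportional to how many rare blocks it can serve. The following is a minimal, self-contained sketch of that idea (not the deleted method itself); the object name and the fixed rarity threshold of 2 are illustrative.

import java.util.BitSet
import scala.util.Random

object RarestFirstSketch {
  // peers(i) is the block bitmap advertised by the i'th idle peer.
  // Returns the index of the chosen peer, or None if nobody holds a rare block.
  def pickPeer(peers: Seq[BitSet], totalBlocks: Int, rnd: Random = new Random): Option[Int] = {
    val copies = Array.fill(totalBlocks)(0)
    for (p <- peers; i <- 0 until totalBlocks if p.get(i)) copies(i) += 1

    // A block is "rare" if at most 2 copies exist in the neighborhood
    val rare = (0 until totalBlocks).filter(i => copies(i) > 0 && copies(i) <= 2)

    // Weight of a peer = number of rare blocks it holds
    val weights = peers.map(p => rare.count(p.get))
    val total = weights.sum
    if (total == 0) return None

    var r = rnd.nextInt(total)                       // weighted sampling over peers
    Some(weights.indexWhere { w => r -= w; r < 0 })
  }
}

When no rare blocks are visible, the deleted code fell back to pickPeerToTalkToRandom; the sketch simply returns None.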
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources(sourceInfo)
- logDebug("Sending selectedSources:" + selectedSources)
- oos.writeObject(selectedSources)
- oos.flush()
-
- // Add this source to the listOfSources
- addToListOfSources(sourceInfo)
- } catch {
- case e: Exception => {
- // Assuming exception caused by receiver failure: remove
- if (listOfSources != null) {
- listOfSources.synchronized { listOfSources -= sourceInfo }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo]()
-
- // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
- // then add skipSourceInfo to setOfCompletedSources. Return blank.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = MultiTracker.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet(listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
- do {
- i = MultiTracker.ranGen.nextInt(listOfSources.size)
- } while (alreadyPicked.get(i))
-
- var peerIter = listOfSources.iterator
- var curPeer = peerIter.next
-
- // Set the BitSet before i is decremented
- alreadyPicked.set(i)
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
-
- selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- // Serve at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ServeSingleRequest is running")
-
- override def run() {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush()
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- addToListOfSources(rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = MultiTracker.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- var blockToSend = ois.readObject.asInstanceOf[Int]
-
- // If it is driver AND at least one copy of each block has not been
- // sent out already, MODIFY blockToSend
- if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
- blockToSend = sentBlocks.getAndIncrement
- }
-
- // Send the block
- sendBlock(blockToSend)
- rxSourceInfo.hasBlocksBitVector.set(blockToSend)
-
- numBlocksToSend -= 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources(rxSourceInfo)
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= MultiTracker.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendBlock(blockToSend: Int) {
- try {
- oos.writeObject(arrayOfBlocks(blockToSend))
- oos.flush()
- } catch {
- case e: Exception => logError("sendBlock had a " + e)
- }
- logDebug("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-private[spark] class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new BitTorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
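The block-selection logic in the file removed above follows the same shape in both pickBlockRandom and pickBlockRarestFirst: start from the local bitmap, optionally also exclude blocks that are already in flight, flip, intersect with what the sender advertises, and pick from what remains. Below is a minimal sketch of the random variant under those assumptions; the endGameFraction parameter mirrors MultiTracker.EndGameFraction, after which in-flight blocks become eligible again so the tail of the transfer is not stalled by a single slow peer.

import java.util.{BitSet, Random}

object BlockPickSketch {
  private val rand = new Random

  def pickBlock(has: BitSet, inFlight: BitSet, senderHas: BitSet,
                totalBlocks: Int, blocksReceived: Int,
                endGameFraction: Double = 0.95): Int = {
    val need = has.clone.asInstanceOf[BitSet]
    if (blocksReceived.toDouble / totalBlocks < endGameFraction) {
      need.or(inFlight)                  // before the end game, skip in-flight blocks too
    }
    need.flip(0, totalBlocks)            // blocks still missing locally
    need.and(senderHas)                  // ...that this sender can actually serve

    val candidates = need.cardinality
    if (candidates == 0) return -1       // nothing useful to ask this peer for

    var idx = need.nextSetBit(0)
    var skip = rand.nextInt(candidates)  // uniform choice among the candidates
    while (skip > 0) { idx = need.nextSetBit(idx + 1); skip -= 1 }
    idx
  }
}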
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 9db26ae6de..609464e38d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
-
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
- def blockId: String = "broadcast_" + id
+ def blockId = BroadcastBlockId(id)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging {
private var server: HttpServer = null
private val files = new TimeStampedHashSet[String]
- private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+ private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
private lazy val compressionCodec = CompressionCodec.createCodec()
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
}
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, "broadcast-" + id)
+ val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
- val url = serverUri + "/broadcast-" + id
+ val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())
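The HttpBroadcast hunk above swaps ad-hoc string concatenation ("broadcast_" + id, "broadcast-" + id) for the typed BroadcastBlockId, so the BlockManager key, the file name under broadcastDir, and the served URL are all derived from one place. A rough sketch of the idea, using a hypothetical stand-in case class rather than the real org.apache.spark.storage.BroadcastBlockId (whose exact name format should be checked against that class):

object BlockIdNamingSketch {
  // Hypothetical stand-in; mirrors the pattern of deriving every name from one typed id
  case class BroadcastBlockId(broadcastId: Long) {
    def name: String = "broadcast_" + broadcastId
  }

  def main(args: Array[String]) {
    val id = BroadcastBlockId(42L)
    val file = new java.io.File("/tmp/broadcast-dir", id.name)   // what write() stores
    val url = "http://driver.example:33000/" + id.name           // what read() fetches
    println(file + " <-> " + url)                                // the two names can no longer drift apart
  }
}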
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
deleted file mode 100644
index 21ec94659e..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
+++ /dev/null
@@ -1,410 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.Random
-
-import scala.collection.mutable.Map
-
-import org.apache.spark._
-import org.apache.spark.util.Utils
-
-private object MultiTracker
-extends Logging {
-
- // Tracker Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
-
- // Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[Long, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var _isDriver = false
-
- private var stopBroadcast = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(__isDriver: Boolean) {
- synchronized {
- if (!initialized) {
- _isDriver = __isDriver
-
- if (isDriver) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
-
- // Set DriverHostAddress to the driver's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
- }
-
- initialized = true
- }
- }
- }
-
- def stop() {
- stopBroadcast = true
- }
-
- // Load common parameters
- private var DriverHostAddress_ = System.getProperty(
- "spark.MultiTracker.DriverHostAddress", "")
- private var DriverTrackerPort_ = System.getProperty(
- "spark.broadcast.driverTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty(
- "spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxChatSlots_ = System.getProperty(
- "spark.broadcast.maxChatSlots", "4").toInt
- private var MaxChatTime_ = System.getProperty(
- "spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty(
- "spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
-
- def isDriver = _isDriver
-
- // Common config params
- def DriverHostAddress = DriverHostAddress_
- def DriverTrackerPort = DriverTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxChatSlots = MaxChatSlots_
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(DriverTrackerPort)
- logInfo("TrackMultipleValues started at " + serverSocket)
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- if (stopBroadcast) {
- logInfo("Stopping TrackMultipleValues...")
- }
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (id -> gInfo)
- }
-
- logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
- }
-
- logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- var gInfo =
- if (valueToGuideMap.contains(id)) valueToGuideMap(id)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logError("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- def getGuideInfo(variableLong: Long): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send Long and receive GuideInfo
- oosTracker.writeObject(variableLong)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => logError("getGuideInfo had a " + e)
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
- def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- // Helper method to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / BlockSize)
- if (byteArray.length % BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, BlockSize)) {
- val thisBlockSize = math.min(BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper method to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
-extends Serializable
-
-private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
-extends Serializable {
- @transient var hasBlocks = 0
-}
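Both the removed MultiTracker helpers and the new TorrentBroadcast split a broadcast value the same way: serialize it, chop the byte array into BlockSize pieces (the last piece may be shorter), and reassemble on the receiving side by copying piece i back at offset i * BlockSize. A minimal round-trip sketch, assuming plain Java serialization as in the code above:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object BlockifySketch {
  def blockify(obj: AnyRef, blockSize: Int): Array[Array[Byte]] = {
    val baos = new ByteArrayOutputStream
    val oos = new ObjectOutputStream(baos)
    oos.writeObject(obj)
    oos.close()
    baos.toByteArray.grouped(blockSize).toArray   // last chunk may be shorter than blockSize
  }

  def unBlockify[T](blocks: Array[Array[Byte]]): T = {
    val in = new ObjectInputStream(new ByteArrayInputStream(blocks.flatten))
    try in.readObject.asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]) {
    val blocks = blockify(Vector(1, 2, 3), blockSize = 16)
    println(blocks.length + " block(s) -> " + unBlockify[Vector[Int]](blocks))
  }
}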
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
deleted file mode 100644
index baa1fd6da4..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.util.BitSet
-
-import org.apache.spark._
-
-/**
- * Used to keep and pass around information of peers involved in a broadcast
- */
-private[spark] case class SourceInfo (hostAddress: String,
- listenPort: Int,
- totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
- def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-/**
- * Helper Object of SourceInfo for its constants
- */
-private[spark] object SourceInfo {
- // Broadcast has not started yet! Should never happen.
- val TxNotStartedRetry = -1
- // Broadcast has already finished. Try default mechanism.
- val TxOverGoToDefault = -3
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
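The sentinel values defined in the removed SourceInfo companion drive the control flow seen throughout the deleted broadcast implementations: a guide that is not registered yet, a finished broadcast, and an explicit stop request are all signalled through the listenPort field. A small sketch of how a receiver interprets them, assuming the same constants:

object SentinelSketch {
  val TxNotStartedRetry = -1   // broadcast not registered yet: sleep and retry
  val StopBroadcast     = -2   // guide asked this peer to stop serving
  val TxOverGoToDefault = -3   // broadcast already finished: fall back to the default mechanism

  def interpret(listenPort: Int): String = listenPort match {
    case TxNotStartedRetry => "retry after a knock interval"
    case StopBroadcast     => "stop the local serve threads"
    case TxOverGoToDefault => "give up on the torrent path and use the default broadcast"
    case port              => "connect to the guide/source on port " + port
  }
}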
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
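The new TorrentBroadcast replaces the socket plumbing above with BlockManager puts and gets: the driver stores a meta block plus one block per piece, and each executor fetches the pieces in random order, re-publishing every piece it receives so later executors can pull from it. A minimal usage sketch; the spark.broadcast.factory property is assumed to be how BroadcastManager selects the factory in this version, and the local[2] master plus sample data are illustrative only:

object TorrentBroadcastUsage {
  def main(args: Array[String]) {
    System.setProperty("spark.broadcast.factory",
      "org.apache.spark.broadcast.TorrentBroadcastFactory")

    val sc = new org.apache.spark.SparkContext("local[2]", "torrent-broadcast-demo")
    val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))       // blockified and stored once
    val total = sc.parallelize(Seq("a", "b", "a"))
                  .map(word => lookup.value(word))           // workers read the cached copy
                  .reduce(_ + _)
    println(total)                                           // 4
    sc.stop()
  }
}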
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
deleted file mode 100644
index 80c97ca073..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ /dev/null
@@ -1,603 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{Comparator, Random, UUID}
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
-
- def value = value_
-
- def blockId = "broadcast_" + id
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- listOfSources += masterSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- var clientSocketToDriver: Socket = null
- var oosDriver: ObjectOutputStream = null
- var oisDriver: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- // Connect to Driver and send this worker's Information
- clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
- oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
- oosDriver.flush()
- oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
-
- logDebug("Connected to Driver's guiding object")
-
- // Send local source information
- oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
- oosDriver.flush()
-
- // Receive source information from Driver
- var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = sourceInfo.totalBytes
-
- logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Driver will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Driver
- oosDriver.writeObject(sourceInfo)
-
- if (oisDriver != null) {
- oisDriver.close()
- }
- if (oosDriver != null) {
- oosDriver.close()
- }
- if (clientSocketToDriver != null) {
- clientSocketToDriver.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- /**
- * Tries to receive broadcast from the source and returns Boolean status.
- * This might be called multiple times to retry a defined number of times.
- */
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logDebug("Inside receiveSingleTransmission")
- logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
-
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
- }
- } catch {
- case e: Exception => logError("receiveSingleTransmission had a " + e)
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close() the socket here; else, the thread will close() it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- var listIter = listOfSources.iterator
- while (listIter.hasNext) {
- var sourceInfo = listIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal
- gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
- // Add this new (if it can finish) source to the list of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
- listOfSources += thisWorkerInfo
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in listOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(listOfSources.contains(selectedSourceInfo))
-
- // Remove first
- // (Currently removing a source based on just one failure notification!)
- listOfSources = listOfSources - selectedSourceInfo
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
- }
- } catch {
- case e: Exception => {
- // Remove failed worker from listOfSources and update leecherCount of
- // corresponding source worker
- listOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- listOfSources = listOfSources - selectedSourceInfo
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
-
- // Remove thisWorkerInfo
- if (listOfSources != null) {
- listOfSources = listOfSources - thisWorkerInfo
- }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Assuming the caller to have a synchronized block on listOfSources
- // Select one with the most leechers. This will level-wise fill the tree
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- var maxLeechers = -1
- var selectedSource: SourceInfo = null
-
- listOfSources.foreach { source =>
- if ((source.hostAddress != skipSourceInfo.hostAddress ||
- source.listenPort != skipSourceInfo.listenPort) &&
- source.currentLeechers < MultiTracker.MaxDegree &&
- source.currentLeechers > maxLeechers) {
- selectedSource = source
- maxLeechers = source.currentLeechers
- }
- }
-
- // Update leecher count
- selectedSource.currentLeechers += 1
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-
- var threadPool = Utils.newDaemonCachedThreadPool()
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => { }
- }
-
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- // If not a valid range, stop broadcast
- if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- sendObject
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
- // Wait till receiving the SourceInfo from Driver
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized { hasBlocksLock.wait() }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => logError("sendObject had a " + e)
- }
- logDebug("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-private[spark] class TreeBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TreeBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index 1cfff5e565..275331724a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -21,12 +21,14 @@ import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
+import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.ExecutorRunner
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
+/** Contains messages sent between Scheduler actor nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@@ -52,17 +54,20 @@ private[deploy] object DeployMessages {
exitStatus: Option[Int])
extends DeployMessage
+ case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
+
case class Heartbeat(workerId: String) extends DeployMessage
// Master to Worker
- case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
+ case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
- case class KillExecutor(appId: String, execId: Int) extends DeployMessage
+ case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
+ masterUrl: String,
appId: String,
execId: Int,
appDesc: ApplicationDescription,
@@ -76,9 +81,11 @@ private[deploy] object DeployMessages {
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
+ case class MasterChangeAcknowledged(appId: String)
+
// Master to Client
- case class RegisteredApplication(appId: String) extends DeployMessage
+ case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@@ -94,6 +101,10 @@ private[deploy] object DeployMessages {
case object StopClient
+ // Master to Worker & Client
+
+ case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
+
// MasterWebUI To Master
case object RequestMasterState
@@ -101,7 +112,8 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
- activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
+ status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
@@ -123,12 +135,7 @@ private[deploy] object DeployMessages {
assert (port > 0)
}
- // Actor System to Master
-
- case object CheckForWorkerTimeOut
-
- case object RequestWebUIPort
-
- case class WebUIPortResponse(webUIBoundPort: Int)
+ // Actor System to Worker
+ case object SendHeartbeat
}
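
The new messages above sketch the failover handshake: the newly elected master broadcasts MasterChanged, workers answer with WorkerSchedulerStateResponse, and drivers answer with MasterChangeAcknowledged. A minimal, hypothetical worker-side handler (not the real Worker actor; `executors` is a placeholder for however the worker tracks its running executors, and Spark-internal visibility applies) might look like:

import akka.actor.Actor
import org.apache.spark.deploy.ExecutorDescription
import org.apache.spark.deploy.DeployMessages.{MasterChanged, WorkerSchedulerStateResponse}

// Hypothetical sketch only: it shows the reply direction of the failover
// handshake, not the real Worker implementation.
class WorkerFailoverHandler(workerId: String, executors: () => List[ExecutorDescription])
  extends Actor {

  def receive = {
    case MasterChanged(masterUrl, masterWebUiUrl) =>
      println("master changed to " + masterUrl)
      // The message comes from the newly elected master, so replying to `sender`
      // reaches it directly; report the executors still running on this worker so
      // the master can rebuild its scheduling state during recovery.
      sender ! WorkerSchedulerStateResponse(workerId, executors())
  }
}
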
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
new file mode 100644
index 0000000000..2abf0b69dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+/**
+ * Used to send state on-the-wire about Executors from Worker to Master.
+ * This state is sufficient for the Master to reconstruct its internal data structures during
+ * failover.
+ */
+private[spark] class ExecutorDescription(
+ val appId: String,
+ val execId: Int,
+ val cores: Int,
+ val state: ExecutorState.Value)
+ extends Serializable {
+
+ override def toString: String =
+    "ExecutorDescription(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
+}
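
For illustration, a description for one running executor could be built like this (all values are made up; a worker would send a list of these inside WorkerSchedulerStateResponse during failover, and ExecutorInfo.copyState consumes them on the master side):

import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}

// Illustrative values only.
val desc = new ExecutorDescription(
  appId = "app-20131112120000-0001",   // hypothetical application id
  execId = 3,
  cores = 4,
  state = ExecutorState.RUNNING)

println(desc)   // e.g. ExecutorDescription(appId=app-20131112120000-0001, execId=3, cores=4, state=RUNNING)
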
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
new file mode 100644
index 0000000000..668032a3a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -0,0 +1,420 @@
+/*
+ *
+ * * Licensed to the Apache Software Foundation (ASF) under one or more
+ * * contributor license agreements. See the NOTICE file distributed with
+ * * this work for additional information regarding copyright ownership.
+ * * The ASF licenses this file to You under the Apache License, Version 2.0
+ * * (the "License"); you may not use this file except in compliance with
+ * * the License. You may obtain a copy of the License at
+ * *
+ * * http://www.apache.org/licenses/LICENSE-2.0
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS,
+ * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * * See the License for the specific language governing permissions and
+ * * limitations under the License.
+ *
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.net.URL
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Await, future, promise}
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable.ListBuffer
+import scala.sys.process._
+
+import net.liftweb.json.JsonParser
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.deploy.master.RecoveryState
+
+/**
+ * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
+ * In order to mimic a real distributed cluster more closely, Docker is used.
+ * Execute using
+ * ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ *
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ * - spark.deploy.recoveryMode=ZOOKEEPER
+ * - spark.deploy.zookeeper.url=172.17.42.1:2181
+ * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
+ *
+ * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
+ * working installation of Docker. In addition to having Docker, the following are assumed:
+ * - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
+ * - The docker images tagged spark-test-master and spark-test-worker are built from the
+ * docker/ directory. Run 'docker/spark-test/build' to generate these.
+ */
+private[spark] object FaultToleranceTest extends App with Logging {
+ val masters = ListBuffer[TestMasterInfo]()
+ val workers = ListBuffer[TestWorkerInfo]()
+ var sc: SparkContext = _
+
+ var numPassed = 0
+ var numFailed = 0
+
+ val sparkHome = System.getenv("SPARK_HOME")
+ assertTrue(sparkHome != null, "Run with a valid SPARK_HOME")
+
+ val containerSparkHome = "/opt/spark"
+ val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome)
+
+ System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip
+
+ def afterEach() {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ terminateCluster()
+ }
+
+ test("sanity-basic") {
+ addMasters(1)
+ addWorkers(1)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("sanity-many-masters") {
+ addMasters(3)
+ addWorkers(3)
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-halt") {
+ addMasters(3)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ createClient()
+ assertValidClusterState()
+ }
+
+ test("single-master-restart") {
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+
+ killLeader()
+ addMasters(1)
+ delay(30 seconds)
+ assertValidClusterState()
+ }
+
+ test("cluster-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ terminateCluster()
+ addMasters(2)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("all-but-standby-failure") {
+ addMasters(2)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+
+ killLeader()
+ workers.foreach(_.kill())
+ workers.clear()
+ delay(30 seconds)
+ addWorkers(2)
+ assertValidClusterState()
+ }
+
+ test("rolling-outage") {
+ addMasters(1)
+ delay()
+ addMasters(1)
+ delay()
+ addMasters(1)
+ addWorkers(2)
+ createClient()
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+
+ (1 to 3).foreach { _ =>
+ killLeader()
+ delay(30 seconds)
+ assertValidClusterState()
+ assertTrue(getLeader == masters.head)
+ addMasters(1)
+ }
+ }
+
+ def test(name: String)(fn: => Unit) {
+ try {
+ fn
+ numPassed += 1
+ logInfo("Passed: " + name)
+ } catch {
+ case e: Exception =>
+ numFailed += 1
+ logError("FAILED: " + name, e)
+ }
+ afterEach()
+ }
+
+ def addMasters(num: Int) {
+ (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
+ }
+
+ def addWorkers(num: Int) {
+ val masterUrls = getMasterUrls(masters)
+ (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
+ }
+
+ /** Creates a SparkContext, which constructs a Client to interact with our cluster. */
+ def createClient() = {
+ if (sc != null) { sc.stop() }
+ // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
+ // property, we need to reset it.
+ System.setProperty("spark.driver.port", "0")
+ sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
+ }
+
+ def getMasterUrls(masters: Seq[TestMasterInfo]): String = {
+ "spark://" + masters.map(master => master.ip + ":7077").mkString(",")
+ }
+
+ def getLeader: TestMasterInfo = {
+ val leaders = masters.filter(_.state == RecoveryState.ALIVE)
+ assertTrue(leaders.size == 1)
+ leaders(0)
+ }
+
+ def killLeader(): Unit = {
+ masters.foreach(_.readState())
+ val leader = getLeader
+ masters -= leader
+ leader.kill()
+ }
+
+ def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
+
+ def terminateCluster() {
+ masters.foreach(_.kill())
+ workers.foreach(_.kill())
+ masters.clear()
+ workers.clear()
+ }
+
+ /** This includes Client retry logic, so it may take a while if the cluster is recovering. */
+ def assertUsable() = {
+ val f = future {
+ try {
+ val res = sc.parallelize(0 until 10).collect()
+ assertTrue(res.toList == (0 until 10))
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertUsable() had exception", e)
+ e.printStackTrace()
+ false
+ }
+ }
+
+ // Avoid waiting indefinitely (e.g., we could register but get no executors).
+ assertTrue(Await.result(f, 120 seconds))
+ }
+
+ /**
+ * Asserts that the cluster is usable and that the expected masters and workers
+ * are all alive in a proper configuration (e.g., only one leader).
+ */
+ def assertValidClusterState() = {
+ assertUsable()
+ var numAlive = 0
+ var numStandby = 0
+ var numLiveApps = 0
+ var liveWorkerIPs: Seq[String] = List()
+
+ def stateValid(): Boolean = {
+ (workers.map(_.ip) -- liveWorkerIPs).isEmpty &&
+ numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1
+ }
+
+ val f = future {
+ try {
+ while (!stateValid()) {
+ Thread.sleep(1000)
+
+ numAlive = 0
+ numStandby = 0
+ numLiveApps = 0
+
+ masters.foreach(_.readState())
+
+ for (master <- masters) {
+ master.state match {
+ case RecoveryState.ALIVE =>
+ numAlive += 1
+ liveWorkerIPs = master.liveWorkerIPs
+ case RecoveryState.STANDBY =>
+ numStandby += 1
+ case _ => // ignore
+ }
+
+ numLiveApps += master.numLiveApps
+ }
+ }
+ true
+ } catch {
+ case e: Exception =>
+ logError("assertValidClusterState() had exception", e)
+ false
+ }
+ }
+
+ try {
+ assertTrue(Await.result(f, 120 seconds))
+ } catch {
+ case e: TimeoutException =>
+ logError("Master states: " + masters.map(_.state))
+ logError("Num apps: " + numLiveApps)
+ logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs)
+ throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e)
+ }
+ }
+
+ def assertTrue(bool: Boolean, message: String = "") {
+ if (!bool) {
+ throw new IllegalStateException("Assertion failed: " + message)
+ }
+ }
+
+ logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed))
+}
+
+private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+ var state: RecoveryState.Value = _
+ var liveWorkerIPs: List[String] = _
+ var numLiveApps = 0
+
+ logDebug("Created master: " + this)
+
+ def readState() {
+ try {
+ val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
+ val json = JsonParser.parse(masterStream, closeAutomatically = true)
+
+ val workers = json \ "workers"
+ val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
+ liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
+
+ numLiveApps = (json \ "activeapps").children.size
+
+ val status = json \\ "status"
+ val stateString = status.extract[String]
+ state = RecoveryState.values.filter(state => state.toString == stateString).head
+ } catch {
+ case e: Exception =>
+ // ignore, no state update
+ logWarning("Exception", e)
+ }
+ }
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s, state=%s]".
+ format(ip, dockerId.id, logFile.getAbsolutePath, state)
+}
+
+private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
+ extends Logging {
+
+ implicit val formats = net.liftweb.json.DefaultFormats
+
+ logDebug("Created worker: " + this)
+
+ def kill() { Docker.kill(dockerId) }
+
+ override def toString: String =
+ "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
+}
+
+private[spark] object SparkDocker {
+ def startMaster(mountDir: String): TestMasterInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestMasterInfo(ip, id, outFile)
+ }
+
+ def startWorker(mountDir: String, masters: String): TestWorkerInfo = {
+ val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir)
+ val (ip, id, outFile) = startNode(cmd)
+ new TestWorkerInfo(ip, id, outFile)
+ }
+
+ private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = {
+ val ipPromise = promise[String]()
+ val outFile = File.createTempFile("fault-tolerance-test", "")
+ outFile.deleteOnExit()
+ val outStream: FileWriter = new FileWriter(outFile)
+ def findIpAndLog(line: String): Unit = {
+ if (line.startsWith("CONTAINER_IP=")) {
+ val ip = line.split("=")(1)
+ ipPromise.success(ip)
+ }
+
+ outStream.write(line + "\n")
+ outStream.flush()
+ }
+
+ dockerCmd.run(ProcessLogger(findIpAndLog _))
+ val ip = Await.result(ipPromise.future, 30 seconds)
+ val dockerId = Docker.getLastProcessId
+ (ip, dockerId, outFile)
+ }
+}
+
+private[spark] class DockerId(val id: String) {
+ override def toString = id
+}
+
+private[spark] object Docker extends Logging {
+ def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
+ val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
+
+ val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
+ logDebug("Run command: " + cmd)
+ cmd
+ }
+
+ def kill(dockerId: DockerId) : Unit = {
+ "docker kill %s".format(dockerId.id).!
+ }
+
+ def getLastProcessId: DockerId = {
+ var id: String = null
+ "docker ps -l -q".!(ProcessLogger(line => id = line))
+ new DockerId(id)
+ }
+} \ No newline at end of file
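
The test above assumes the recovery settings reach the daemons through SPARK_DAEMON_JAVA_OPTS. In this commit they are read as plain Java system properties (see Master.scala further down), so a rough in-process equivalent, for experimentation only, is:

// Rough sketch: the Master reads these with System.getProperty, so in practice they
// are passed to the daemon JVM (e.g. -Dspark.deploy.recoveryMode=ZOOKEEPER in
// SPARK_DAEMON_JAVA_OPTS). The ZK url below is the value suggested in the comment above.
System.setProperty("spark.deploy.recoveryMode", "ZOOKEEPER")
System.setProperty("spark.deploy.zookeeper.url", "172.17.42.1:2181")

// A FILESYSTEM mode also exists in this commit; it needs a directory instead:
// System.setProperty("spark.deploy.recoveryMode", "FILESYSTEM")
// System.setProperty("spark.deploy.recoveryDirectory", "/tmp/spark-recovery")
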
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
index 04d01c169d..e607b8c6f4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -72,7 +72,8 @@ private[spark] object JsonProtocol {
("memory" -> obj.workers.map(_.memory).sum) ~
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
- ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
+ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
+ ("status" -> obj.status.toString)
}
def writeWorkerState(obj: WorkerStateResponse) = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 6a7d5a85ba..94cf4ff88b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
- def start(): String = {
+ def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
+ val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
- memoryPerWorker, masterUrl, null, Some(workerNum))
+ memoryPerWorker, masters, null, Some(workerNum))
workerActorSystems += workerSystem
}
- return masterUrl
+ return masters
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 993ba6bd3d..c29a30184a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,28 +17,59 @@
package org.apache.spark.deploy
-import com.google.common.collect.MapMaker
+import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.SparkException
/**
- * Contains util methods to interact with Hadoop from spark.
+ * Contains util methods to interact with Hadoop from Spark.
*/
+private[spark]
class SparkHadoopUtil {
- // A general, soft-reference map for metadata needed during HadoopRDD split computation
- // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
- private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
- // subsystems
+ def runAsUser(user: String)(func: () => Unit) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ }
+
+ /**
+   * Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop
+   * subsystems.
+ */
def newConfiguration(): Configuration = new Configuration()
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
+ /**
+ * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ * cluster.
+ */
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+}
+
+object SparkHadoopUtil {
+ private val hadoop = {
+ val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if (yarnMode) {
+ try {
+ Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ } catch {
+ case th: Throwable => throw new SparkException("Unable to load YARN support", th)
+ }
+ } else {
+ new SparkHadoopUtil
+ }
+ }
+ def get: SparkHadoopUtil = {
+ hadoop
+ }
}
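
A short usage sketch of the new singleton accessor and runAsUser helper (the user name and body are placeholders; SparkHadoopUtil is private[spark], so this only compiles from Spark-internal code):

import org.apache.spark.deploy.SparkHadoopUtil

// SparkHadoopUtil.get picks the YARN-aware implementation when SPARK_YARN_MODE
// is set, otherwise the plain one defined above.
val util = SparkHadoopUtil.get
val conf = util.newConfiguration()

// Run a block of work as another Hadoop user (placeholder user and body).
util.runAsUser("etl-user") { () =>
  println("running with UGI for etl-user, fs.defaultFS = " + conf.get("fs.defaultFS"))
}
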
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 164386782c..be8693ec54 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.concurrent.Await
+import scala.concurrent.ExecutionContext.Implicits.global
import akka.actor._
import akka.actor.Terminated
@@ -37,41 +38,81 @@ import org.apache.spark.deploy.master.Master
/**
* The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
* and a listener for cluster events, and calls back the listener when various events occur.
+ *
+ * @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class Client(
actorSystem: ActorSystem,
- masterUrl: String,
+ masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
var actor: ActorRef = null
var appId: String = null
+ var registered = false
+ var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
var master: ActorRef = null
var masterAddress: Address = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
+ var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
- logInfo("Connecting to master " + masterUrl)
try {
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- masterAddress = master.path.address
- master ! RegisterApplication(appDescription)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ registerWithMaster()
} catch {
case e: Exception =>
- logError("Failed to connect to master", e)
+ logWarning("Failed to connect to master", e)
markDisconnected()
context.stop(self)
}
}
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterApplication(appDescription)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ markDead()
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
+ }
+
+ def changeMaster(url: String) {
+ activeMasterUrl = url
+ master = context.actorFor(Master.toAkkaUrl(url))
+ masterAddress = master.path.address
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ }
+
override def receive = {
- case RegisteredApplication(appId_) =>
+ case RegisteredApplication(appId_, masterUrl) =>
appId = appId_
+ registered = true
+ changeMaster(masterUrl)
listener.connected(appId)
case ApplicationRemoved(message) =>
@@ -92,23 +133,27 @@ private[spark] class Client(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl)
+ alreadyDisconnected = false
+ sender ! MasterChangeAcknowledged(appId)
+
case Terminated(actor_) if actor_ == master =>
- logError("Connection to master failed; stopping client")
+ logWarning("Connection to master failed; waiting for master to reconnect...")
markDisconnected()
- context.stop(self)
case DisassociatedEvent(_, address, _) if address == masterAddress =>
logError("Connection to master failed; stopping client")
markDisconnected()
- context.stop(self)
case AssociationErrorEvent(_, _, address, _) if address == masterAddress =>
logError("Connection to master failed; stopping client")
markDisconnected()
- context.stop(self)
case StopClient =>
- markDisconnected()
+ markDead()
sender ! true
context.stop(self)
}
@@ -122,6 +167,13 @@ private[spark] class Client(
alreadyDisconnected = true
}
}
+
+ def markDead() {
+ if (!alreadyDead) {
+ listener.dead()
+ alreadyDead = true
+ }
+ }
}
def start() {
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
index 4605368c11..be7a11bd15 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
@@ -27,8 +27,12 @@ package org.apache.spark.deploy.client
private[spark] trait ClientListener {
def connected(appId: String): Unit
+ /** Disconnection may be a temporary state, as we fail over to a new Master. */
def disconnected(): Unit
+ /** Dead means that we couldn't find any Masters to connect to, and have given up. */
+ def dead(): Unit
+
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
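
As a generic illustration of the distinction documented above, a hypothetical listener (logging only; ClientListener is private[spark]) could treat disconnected() as a transient fail-over and dead() as the point to give up:

import org.apache.spark.deploy.client.ClientListener

// Hypothetical listener: disconnected() may be transient (master fail-over),
// while dead() means every master was given up on.
class LoggingListener extends ClientListener {
  def connected(appId: String) = println("registered as " + appId)
  def disconnected() = println("lost master, waiting for fail-over...")
  def dead() = { println("no masters left, shutting down"); sys.exit(1) }
  def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) =
    println("executor added: " + fullId + " on " + hostPort)
  def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) =
    println("executor removed: " + fullId + " (" + message + ")")
}
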
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
index d5e9a0e095..5b62d3ba6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -33,6 +33,11 @@ private[spark] object TestClient {
System.exit(0)
}
+ def dead() {
+ logInfo("Could not connect to master")
+ System.exit(0)
+ }
+
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
@@ -44,7 +49,7 @@ private[spark] object TestClient {
val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
- val client = new Client(actorSystem, url, desc, listener)
+ val client = new Client(actorSystem, Array(url), desc, listener)
client.start()
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index bd5327627a..5150b7c7de 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -29,23 +29,46 @@ private[spark] class ApplicationInfo(
val submitDate: Date,
val driver: ActorRef,
val appUiUrl: String)
-{
- var state = ApplicationState.WAITING
- var executors = new mutable.HashMap[Int, ExecutorInfo]
- var coresGranted = 0
- var endTime = -1L
- val appSource = new ApplicationSource(this)
-
- private var nextExecutorId = 0
-
- def newExecutorId(): Int = {
- val id = nextExecutorId
- nextExecutorId += 1
- id
+ extends Serializable {
+
+ @transient var state: ApplicationState.Value = _
+ @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
+ @transient var coresGranted: Int = _
+ @transient var endTime: Long = _
+ @transient var appSource: ApplicationSource = _
+
+ @transient private var nextExecutorId: Int = _
+
+ init()
+
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ state = ApplicationState.WAITING
+ executors = new mutable.HashMap[Int, ExecutorInfo]
+ coresGranted = 0
+ endTime = -1L
+ appSource = new ApplicationSource(this)
+ nextExecutorId = 0
+ }
+
+ private def newExecutorId(useID: Option[Int] = None): Int = {
+ useID match {
+ case Some(id) =>
+ nextExecutorId = math.max(nextExecutorId, id + 1)
+ id
+ case None =>
+ val id = nextExecutorId
+ nextExecutorId += 1
+ id
+ }
}
- def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
+ val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
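
The useID parameter exists so that executors reported back by workers during recovery keep their original ids while freshly launched executors never collide with them. The allocation rule, re-implemented standalone (REPL-style) for illustration:

// Standalone copy of the id-allocation rule above, for illustration only.
var nextExecutorId = 0
def newExecutorId(useID: Option[Int] = None): Int = useID match {
  case Some(id) =>
    nextExecutorId = math.max(nextExecutorId, id + 1)
    id
  case None =>
    val id = nextExecutorId
    nextExecutorId += 1
    id
}

assert(newExecutorId(Some(7)) == 7)  // a recovered executor keeps its old id
assert(newExecutorId() == 8)         // the next fresh executor does not reuse 0..7
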
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 39ef090ddf..a74d7be4c9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -22,7 +22,7 @@ private[spark] object ApplicationState
type ApplicationState = Value
- val WAITING, RUNNING, FINISHED, FAILED = Value
+ val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
index cf384a985e..76db61dd61 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
private[spark] class ExecutorInfo(
val id: Int,
@@ -28,5 +28,10 @@ private[spark] class ExecutorInfo(
var state = ExecutorState.LAUNCHING
+ /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
+ def copyState(execDesc: ExecutorDescription) {
+ state = execDesc.state
+ }
+
def fullId: String = application.id + "/" + id
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
new file mode 100644
index 0000000000..043945a211
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import java.io._
+
+import scala.Serializable
+
+import akka.serialization.Serialization
+import org.apache.spark.Logging
+
+/**
+ * Stores data in a single on-disk directory with one file per application and worker.
+ * Files are deleted when applications and workers are removed.
+ *
+ * @param dir Directory to store files. Created if non-existent (but not recursively).
+ * @param serialization Used to serialize our objects.
+ */
+private[spark] class FileSystemPersistenceEngine(
+ val dir: String,
+ val serialization: Serialization)
+ extends PersistenceEngine with Logging {
+
+ new File(dir).mkdir()
+
+ override def addApplication(app: ApplicationInfo) {
+ val appFile = new File(dir + File.separator + "app_" + app.id)
+ serializeIntoFile(appFile, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ new File(dir + File.separator + "app_" + app.id).delete()
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ val workerFile = new File(dir + File.separator + "worker_" + worker.id)
+ serializeIntoFile(workerFile, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ new File(dir + File.separator + "worker_" + worker.id).delete()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
+ val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(file: File, value: AnyRef) {
+ val created = file.createNewFile()
+ if (!created) { throw new IllegalStateException("Could not create file: " + file) }
+
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+
+ val out = new FileOutputStream(file)
+ out.write(serialized)
+ out.close()
+ }
+
+ def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ val fileData = new Array[Byte](file.length().asInstanceOf[Int])
+ val dis = new DataInputStream(new FileInputStream(file))
+ dis.readFully(fileData)
+ dis.close()
+
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
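
A minimal round-trip sketch of the file-system engine (directory and actor-system name are placeholders; it relies on Akka's SerializationExtension, which is how the Master constructs the engine, and Spark-internal visibility applies):

import akka.actor.ActorSystem
import akka.serialization.SerializationExtension
import org.apache.spark.deploy.master.FileSystemPersistenceEngine

// Placeholder actor system and directory.
val system = ActorSystem("persistence-sketch")
val engine = new FileSystemPersistenceEngine("/tmp/spark-recovery", SerializationExtension(system))

// addApplication/addWorker write one file per object ("app_<id>" / "worker_<id>");
// readPersistedData() scans the directory and deserializes both lists on recovery.
val (apps, workers) = engine.readPersistedData()
println("recovered " + apps.size + " apps and " + workers.size + " workers")

system.shutdown()
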
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
new file mode 100644
index 0000000000..f25a1ad3bf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.{Actor, ActorRef}
+
+import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+
+/**
+ * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
+ * is the only Master serving requests.
+ * In addition to the API provided, the LeaderElectionAgent will use the following messages
+ * to inform the Master of leader changes:
+ * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
+ * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ */
+private[spark] trait LeaderElectionAgent extends Actor {
+ val masterActor: ActorRef
+}
+
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
+ override def preStart() {
+ masterActor ! ElectedLeader
+ }
+
+ override def receive = {
+ case _ =>
+ }
+}
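
Besides MonarchyLeaderAgent and the ZooKeeper-based agent referenced from Master.scala, the trait leaves room for other election schemes. A hypothetical agent driven by an external isLeader check (the MasterMessages file further down defines the CheckLeader tick it would rely on; Spark-internal visibility applies):

import akka.actor.ActorRef
import org.apache.spark.deploy.master.LeaderElectionAgent
import org.apache.spark.deploy.master.MasterMessages.{CheckLeader, ElectedLeader, RevokedLeadership}

// Hypothetical sketch: `isLeader` is a placeholder for any external election
// primitive, and something (e.g. the actor system scheduler) must send CheckLeader
// ticks periodically.
class PollingLeaderAgent(val masterActor: ActorRef, isLeader: () => Boolean)
  extends LeaderElectionAgent {

  private var leader = false

  override def receive = {
    case CheckLeader =>
      val nowLeader = isLeader()
      if (nowLeader && !leader) masterActor ! ElectedLeader     // gained leadership
      if (!nowLeader && leader) masterActor ! RevokedLeadership // lost leadership
      leader = nowLeader
  }
}
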
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index cb0fe6a850..26f980760d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -23,23 +23,25 @@ import java.text.SimpleDateFormat
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Await
import scala.concurrent.duration._
+import scala.concurrent.duration.{ Duration, FiniteDuration }
+import scala.concurrent.ExecutionContext.Implicits.global
import akka.actor._
import akka.pattern.ask
import akka.remote._
+import akka.util.Timeout
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
-import akka.util.Timeout
import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
import org.apache.spark.deploy.DeployMessages.KillExecutor
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import scala.Some
-import org.apache.spark.deploy.DeployMessages.WebUIPortResponse
import org.apache.spark.deploy.DeployMessages.LaunchExecutor
import org.apache.spark.deploy.DeployMessages.RegisteredApplication
import org.apache.spark.deploy.DeployMessages.RegisterWorker
@@ -51,6 +53,8 @@ import org.apache.spark.deploy.DeployMessages.ApplicationRemoved
import org.apache.spark.deploy.DeployMessages.Heartbeat
import org.apache.spark.deploy.DeployMessages.RegisteredWorker
import akka.actor.Terminated
+import akka.serialization.SerializationExtension
+import java.util.concurrent.TimeUnit
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
@@ -58,7 +62,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
-
+ val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
+ val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
+
var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
@@ -88,52 +94,115 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (envVar != null) envVar else host
}
+ val masterUrl = "spark://" + host + ":" + port
+ var masterWebUiUrl: String = _
+
+ var state = RecoveryState.STANDBY
+
+ var persistenceEngine: PersistenceEngine = _
+
+ var leaderElectionAgent: ActorRef = _
+
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
- logInfo("Starting Spark master at spark://" + host + ":" + port)
+ logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
- import context.dispatcher
+ masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+
+ persistenceEngine = RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ logInfo("Persisting recovery state to ZooKeeper")
+ new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
+ case "FILESYSTEM" =>
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ case _ =>
+ new BlackHolePersistenceEngine()
+ }
+
+ leaderElectionAgent = RECOVERY_MODE match {
+ case "ZOOKEEPER" =>
+ context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
+ case _ =>
+ context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
+ }
+ }
+
+ override def preRestart(reason: Throwable, message: Option[Any]) {
+ super.preRestart(reason, message) // calls postStop()!
+ logError("Master actor restarted due to exception", reason)
}
override def postStop() {
webUi.stop()
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
+ persistenceEngine.close()
+ context.stop(leaderElectionAgent)
}
override def receive = {
- case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
+ case ElectedLeader => {
+ val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
+ state = if (storedApps.isEmpty && storedWorkers.isEmpty)
+ RecoveryState.ALIVE
+ else
+ RecoveryState.RECOVERING
+
+ logInfo("I have been elected leader! New state: " + state)
+
+ if (state == RecoveryState.RECOVERING) {
+ beginRecovery(storedApps, storedWorkers)
+ context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
+ }
+ }
+
+ case RevokedLeadership => {
+ logError("Leadership has been revoked -- master shutting down.")
+ System.exit(0)
+ }
+
+ case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
- if (idToWorker.contains(id)) {
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ registerWorker(worker)
context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
+ persistenceEngine.addWorker(worker)
+ sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
}
}
case RegisterApplication(description) => {
- logInfo("Registering app " + description.name)
- val app = addApplication(description, sender)
- logInfo("Registered app " + description.name + " with ID " + app.id)
- waitingApps += app
- context.watch(sender) // This doesn't work with remote actors but helps for testing
- sender ! RegisteredApplication(app.id)
- schedule()
+ if (state == RecoveryState.STANDBY) {
+ // ignore, don't send response
+ } else {
+ logInfo("Registering app " + description.name)
+ val app = createApplication(description, sender)
+ registerApplication(app)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ context.watch(sender) // This doesn't work with remote actors but helps for testing
+ persistenceEngine.addApplication(app)
+ sender ! RegisteredApplication(app.id, masterUrl)
+ schedule()
+ }
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
@@ -173,27 +242,63 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ case MasterChangeAcknowledged(appId) => {
+ idToApp.get(appId) match {
+ case Some(app) =>
+ logInfo("Application has been re-registered: " + appId)
+ app.state = ApplicationState.WAITING
+ case None =>
+ logWarning("Master change ack from unknown app: " + appId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
+ case WorkerSchedulerStateResponse(workerId, executors) => {
+ idToWorker.get(workerId) match {
+ case Some(worker) =>
+ logInfo("Worker has been re-registered: " + workerId)
+ worker.state = WorkerState.ALIVE
+
+ val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined)
+ for (exec <- validExecutors) {
+ val app = idToApp.get(exec.appId).get
+ val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId))
+ worker.addExecutor(execInfo)
+ execInfo.copyState(exec)
+ }
+ case None =>
+ logWarning("Scheduler state from unknown worker: " + workerId)
+ }
+
+ if (canCompleteRecovery) { completeRecovery() }
+ }
+
case Terminated(actor) => {
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
actorToApp.get(actor).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case AssociationErrorEvent(_, _, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
+ if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RequestMasterState => {
- sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
+ sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
+ state)
}
case CheckForWorkerTimeOut => {
@@ -205,6 +310,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
+ def canCompleteRecovery =
+ workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
+ apps.count(_.state == ApplicationState.UNKNOWN) == 0
+
+ def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
+ for (app <- storedApps) {
+ logInfo("Trying to recover app: " + app.id)
+ try {
+ registerApplication(app)
+ app.state = ApplicationState.UNKNOWN
+ app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
+ }
+ }
+
+ for (worker <- storedWorkers) {
+ logInfo("Trying to recover worker: " + worker.id)
+ try {
+ registerWorker(worker)
+ worker.state = WorkerState.UNKNOWN
+ worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
+ } catch {
+ case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
+ }
+ }
+ }
+
+ def completeRecovery() {
+ // Ensure "only-once" recovery semantics using a short synchronization period.
+ synchronized {
+ if (state != RecoveryState.RECOVERING) { return }
+ state = RecoveryState.COMPLETING_RECOVERY
+ }
+
+ // Kill off any workers and apps that didn't respond to us.
+ workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
+ apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
+
+ state = RecoveryState.ALIVE
+ schedule()
+ logInfo("Recovery complete - resuming operations!")
+ }
+
/**
* Can an app use the given worker? True if the worker has enough memory and we haven't already
* launched an executor for the app on it (right now the standalone backend doesn't like having
@@ -219,6 +368,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
* every time a new app joins or resource availability changes.
*/
def schedule() {
+ if (state != RecoveryState.ALIVE) { return }
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
// in the queue, then the second app, etc.
if (spreadOutApps) {
@@ -266,14 +416,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
- worker.actor ! LaunchExecutor(
+ worker.actor ! LaunchExecutor(masterUrl,
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(
exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
- def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
- publicAddress: String): WorkerInfo = {
+ def registerWorker(worker: WorkerInfo): Unit = {
// There may be one or more refs to dead workers on this same node (w/ different ID's),
// remove them.
workers.filter { w =>
@@ -281,12 +430,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}.foreach { w =>
workers -= w
}
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+
+ val workerAddress = worker.actor.path.address
+ if (addressToWorker.contains(workerAddress)) {
+ logInfo("Attempted to re-register worker at same address: " + workerAddress)
+ return
+ }
+
workers += worker
idToWorker(worker.id) = worker
- actorToWorker(sender) = worker
- addressToWorker(sender.path.address) = worker
- worker
+ actorToWorker(worker.actor) = worker
+ addressToWorker(workerAddress) = worker
}
def removeWorker(worker: WorkerInfo) {
@@ -301,25 +455,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.id, ExecutorState.LOST, Some("worker lost"), None)
exec.application.removeExecutor(exec)
}
+ persistenceEngine.removeWorker(worker)
}
- def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
- val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ }
+
+ def registerApplication(app: ApplicationInfo): Unit = {
+ val appAddress = app.driver.path.address
+    if (addressToApp.contains(appAddress)) {
+ logInfo("Attempted to re-register application at same address: " + appAddress)
+ return
+ }
+
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
- actorToApp(driver) = app
- addressToApp(driver.path.address) = app
+ actorToApp(app.driver) = app
+ addressToApp(appAddress) = app
if (firstApp == None) {
firstApp = Some(app)
}
+ // TODO: What is firstApp?? Can we remove it?
val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
- if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
+ if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
}
- app
+ waitingApps += app
}
def finishApplication(app: ApplicationInfo) {
@@ -344,13 +509,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
- exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
app.driver ! ApplicationRemoved(state.toString)
}
+ persistenceEngine.removeApplication(app)
schedule()
}
}
@@ -404,8 +570,8 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName)
- val timeoutDuration = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeoutDuration: FiniteDuration = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS)
implicit val timeout = Timeout(timeoutDuration)
val respFuture = actor ? RequestWebUIPort // ask pattern
val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
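
An illustrative aside, not part of the patch: the block above relies on Akka's ask pattern, which needs an implicit Timeout for the `?` operator and a blocking Await on the untyped reply. A minimal sketch with invented actor and message names:

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import akka.actor.ActorRef
    import akka.pattern.ask
    import akka.util.Timeout

    object AskSketch {
      case object RequestPort
      case class PortResponse(port: Int)

      // Blocks the caller until the actor replies or the timeout elapses.
      def boundPort(actor: ActorRef): Int = {
        implicit val timeout = Timeout(10.seconds)   // picked up implicitly by `?`
        val future = actor ? RequestPort             // Future[Any]
        Await.result(future, timeout.duration).asInstanceOf[PortResponse].port
      }
    }
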
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
new file mode 100644
index 0000000000..74a9f8cd82
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+sealed trait MasterMessages extends Serializable
+
+/** Contains messages seen only by the Master and its associated entities. */
+private[master] object MasterMessages {
+
+ // LeaderElectionAgent to Master
+
+ case object ElectedLeader
+
+ case object RevokedLeadership
+
+ // Actor System to LeaderElectionAgent
+
+ case object CheckLeader
+
+ // Actor System to Master
+
+ case object CheckForWorkerTimeOut
+
+ case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo])
+
+ case object CompleteRecovery
+
+ case object RequestWebUIPort
+
+ case class WebUIPortResponse(webUIBoundPort: Int)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
new file mode 100644
index 0000000000..94b986caf2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+/**
+ * Allows Master to persist any state that is necessary in order to recover from a failure.
+ * The following semantics are required:
+ * - addApplication and addWorker are called before completing registration of a new app/worker.
+ * - removeApplication and removeWorker may be called at any time.
+ * Given these two requirements, we will have all apps and workers persisted, but
+ * we might not yet have deleted apps or workers that finished (so their liveness must be verified
+ * during recovery).
+ */
+private[spark] trait PersistenceEngine {
+ def addApplication(app: ApplicationInfo)
+
+ def removeApplication(app: ApplicationInfo)
+
+ def addWorker(worker: WorkerInfo)
+
+ def removeWorker(worker: WorkerInfo)
+
+ /**
+ * Returns the persisted data sorted by their respective ids (which implies that they're
+ * sorted by time of creation).
+ */
+ def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])
+
+ def close() {}
+}
+
+private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
+ override def addApplication(app: ApplicationInfo) {}
+ override def removeApplication(app: ApplicationInfo) {}
+ override def addWorker(worker: WorkerInfo) {}
+ override def removeWorker(worker: WorkerInfo) {}
+ override def readPersistedData() = (Nil, Nil)
+}
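
As a rough illustration of how the trait above is meant to be consumed (this sketch is not in the patch; markUnknown is a hypothetical stand-in for the Master's real recovery bookkeeping):

    package org.apache.spark.deploy.master

    private[spark] object RecoverySketch {
      // Drain the persisted state; everything read back may already be dead, so each
      // entry is only flagged here and its liveness must be verified afterwards.
      def recover(engine: PersistenceEngine)(markUnknown: AnyRef => Unit): Unit = {
        val (storedApps, storedWorkers) = engine.readPersistedData()
        storedApps.foreach(markUnknown)
        storedWorkers.foreach(markUnknown)
      }
    }
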
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
new file mode 100644
index 0000000000..b91be821f0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object RecoveryState extends Enumeration {
+
+ type MasterState = Value
+
+ val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
new file mode 100644
index 0000000000..81e15c534f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import scala.collection.JavaConversions._
+import scala.concurrent.ops._
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+import org.apache.zookeeper.data.Stat
+import org.apache.zookeeper.Watcher.Event.KeeperState
+
+/**
+ * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
+ * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
+ * created. If ZooKeeper remains down after several retries, the given
+ * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
+ * informed via zkDown().
+ *
+ * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
+ * times or a semantic exception is thrown (e.g., "node already exists").
+ */
+private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
+ val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
+
+ val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
+ val ZK_TIMEOUT_MILLIS = 30000
+ val RETRY_WAIT_MILLIS = 5000
+ val ZK_CHECK_PERIOD_MILLIS = 10000
+ val MAX_RECONNECT_ATTEMPTS = 3
+
+ private var zk: ZooKeeper = _
+
+ private val watcher = new ZooKeeperWatcher()
+ private var reconnectAttempts = 0
+ private var closed = false
+
+ /** Connect to ZooKeeper to start the session. Must be called before anything else. */
+ def connect() {
+ connectToZooKeeper()
+
+ new Thread() {
+ override def run() = sessionMonitorThread()
+ }.start()
+ }
+
+ def sessionMonitorThread(): Unit = {
+ while (!closed) {
+ Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
+ if (zk.getState != ZooKeeper.States.CONNECTED) {
+ reconnectAttempts += 1
+ val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
+ if (attemptsLeft <= 0) {
+ logError("Could not connect to ZooKeeper: system failure")
+ zkWatcher.zkDown()
+ close()
+ } else {
+ logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
+ connectToZooKeeper()
+ }
+ }
+ }
+ }
+
+ def close() {
+ if (!closed && zk != null) { zk.close() }
+ closed = true
+ }
+
+ private def connectToZooKeeper() {
+ if (zk != null) zk.close()
+ zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
+ }
+
+ /**
+ * Attempts to maintain a live ZooKeeper session despite (very) transient failures.
+ * Mainly useful for handling the natural ZooKeeper session expiration.
+ */
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (closed) { return }
+
+ event.getState match {
+ case KeeperState.SyncConnected =>
+ reconnectAttempts = 0
+ zkWatcher.zkSessionCreated()
+ case KeeperState.Expired =>
+ connectToZooKeeper()
+ case KeeperState.Disconnected =>
+ logWarning("ZooKeeper disconnected, will retry...")
+ }
+ }
+ }
+
+ def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
+ retry {
+ zk.create(path, bytes, ZK_ACL, createMode)
+ }
+ }
+
+ def exists(path: String, watcher: Watcher = null): Stat = {
+ retry {
+ zk.exists(path, watcher)
+ }
+ }
+
+ def getChildren(path: String, watcher: Watcher = null): List[String] = {
+ retry {
+ zk.getChildren(path, watcher).toList
+ }
+ }
+
+ def getData(path: String): Array[Byte] = {
+ retry {
+ zk.getData(path, false, null)
+ }
+ }
+
+ def delete(path: String, version: Int = -1): Unit = {
+ retry {
+ zk.delete(path, version)
+ }
+ }
+
+ /**
+ * Creates the given directory (non-recursively) if it doesn't exist.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdir(path: String) {
+ if (exists(path) == null) {
+ try {
+ create(path, "".getBytes, CreateMode.PERSISTENT)
+ } catch {
+ case e: Exception =>
+ // If the exception caused the directory not to be created, bubble it up,
+ // otherwise ignore it.
+ if (exists(path) == null) { throw e }
+ }
+ }
+ }
+
+ /**
+ * Recursively creates all directories up to the given one.
+ * All znodes are created in PERSISTENT mode with no data.
+ */
+ def mkdirRecursive(path: String) {
+ var fullDir = ""
+ for (dentry <- path.split("/").tail) {
+ fullDir += "/" + dentry
+ mkdir(fullDir)
+ }
+ }
+
+ /**
+ * Retries the given function up to 3 times. The assumption is that failure is transient,
+ * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
+ * in which case the exception will be thrown without retries.
+ *
+ * @param fn Block to execute, possibly multiple times.
+ */
+ def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
+ try {
+ fn
+ } catch {
+ case e: KeeperException.NoNodeException => throw e
+ case e: KeeperException.NodeExistsException => throw e
+ case e if n > 0 =>
+ logError("ZooKeeper exception, " + n + " more retries...", e)
+ Thread.sleep(RETRY_WAIT_MILLIS)
+ retry(fn, n-1)
+ }
+ }
+}
+
+trait SparkZooKeeperWatcher {
+ /**
+ * Called whenever a ZK session is created --
+ * this will occur when we create our first session as well as each time
+ * the session expires or errors out.
+ */
+ def zkSessionCreated()
+
+ /**
+ * Called if ZK appears to be completely down (i.e., not just a transient error).
+ * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
+ */
+ def zkDown()
+}
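
A minimal usage sketch of the session class above, assuming a ZooKeeper ensemble is reachable via spark.deploy.zookeeper.url; the znode paths and payload here are invented:

    package org.apache.spark.deploy.master

    import org.apache.zookeeper.CreateMode

    private[spark] object ZkSessionSketch extends SparkZooKeeperWatcher {
      private val zk = new SparkZooKeeperSession(this)

      // Called on every (re)connect, so the setup done here must be idempotent.
      override def zkSessionCreated() {
        zk.mkdirRecursive("/spark/example")
        zk.create("/spark/example/node_", "hello".getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
      }

      override def zkDown() {
        sys.error("ZooKeeper considered dead, giving up")
      }

      def main(args: Array[String]) {
        zk.connect()   // must be the first call; real work happens in zkSessionCreated()
      }
    }
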
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 6219f11f2a..e05f587b58 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -22,28 +22,44 @@ import scala.collection.mutable
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
- val id: String,
- val host: String,
- val port: Int,
- val cores: Int,
- val memory: Int,
- val actor: ActorRef,
- val webUiPort: Int,
- val publicAddress: String) {
+ val id: String,
+ val host: String,
+ val port: Int,
+ val cores: Int,
+ val memory: Int,
+ val actor: ActorRef,
+ val webUiPort: Int,
+ val publicAddress: String)
+ extends Serializable {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
- var state: WorkerState.Value = WorkerState.ALIVE
- var coresUsed = 0
- var memoryUsed = 0
+ @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
+ @transient var state: WorkerState.Value = _
+ @transient var coresUsed: Int = _
+ @transient var memoryUsed: Int = _
- var lastHeartbeat = System.currentTimeMillis()
+ @transient var lastHeartbeat: Long = _
+
+ init()
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
+ private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ in.defaultReadObject()
+ init()
+ }
+
+ private def init() {
+ executors = new mutable.HashMap
+ state = WorkerState.ALIVE
+ coresUsed = 0
+ memoryUsed = 0
+ lastHeartbeat = System.currentTimeMillis()
+ }
+
def hostPort: String = {
assert (port > 0)
host + ":" + port
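
The WorkerInfo change above uses a common Java-serialization idiom: transient runtime state is rebuilt by init(), which runs both at construction time and again after defaultReadObject(). A generic, self-contained sketch of the same pattern (not code from the patch):

    import java.io.ObjectInputStream

    class CachedCounter(val name: String) extends Serializable {
      @transient private var hits: Int = _

      init()

      private def init() { hits = 0 }

      // Java serialization hook: restore the transient fields defaultReadObject skipped.
      private def readObject(in: ObjectInputStream) {
        in.defaultReadObject()
        init()
      }

      def hit(): Int = { hits += 1; hits }
    }
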
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index fb3fe88d92..0b36ef6005 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -20,5 +20,5 @@ package org.apache.spark.deploy.master
private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
- val ALIVE, DEAD, DECOMMISSIONED = Value
+ val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
new file mode 100644
index 0000000000..7809013e83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.ActorRef
+import org.apache.zookeeper._
+import org.apache.zookeeper.Watcher.Event.EventType
+
+import org.apache.spark.deploy.master.MasterMessages._
+import org.apache.spark.Logging
+
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
+ extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
+
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
+
+ private val watcher = new ZooKeeperWatcher()
+ private val zk = new SparkZooKeeperSession(this)
+ private var status = LeadershipStatus.NOT_LEADER
+ private var myLeaderFile: String = _
+ private var leaderUrl: String = _
+
+ override def preStart() {
+ logInfo("Starting ZooKeeper LeaderElection agent")
+ zk.connect()
+ }
+
+ override def zkSessionCreated() {
+ synchronized {
+ zk.mkdirRecursive(WORKING_DIR)
+ myLeaderFile =
+ zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
+ self ! CheckLeader
+ }
+ }
+
+ override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
+ logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
+ Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
+ super.preRestart(reason, message)
+ }
+
+ override def zkDown() {
+ logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
+ System.exit(1)
+ }
+
+ override def postStop() {
+ zk.close()
+ }
+
+ override def receive = {
+ case CheckLeader => checkLeader()
+ }
+
+ private class ZooKeeperWatcher extends Watcher {
+ def process(event: WatchedEvent) {
+ if (event.getType == EventType.NodeDeleted) {
+ logInfo("Leader file disappeared, a master is down!")
+ self ! CheckLeader
+ }
+ }
+ }
+
+ /** Uses ZK leader election. Navigates several ZK potholes along the way. */
+ def checkLeader() {
+ val masters = zk.getChildren(WORKING_DIR).toList
+ val leader = masters.sorted.head
+ val leaderFile = WORKING_DIR + "/" + leader
+
+ // Setup a watch for the current leader.
+ zk.exists(leaderFile, watcher)
+
+ try {
+ leaderUrl = new String(zk.getData(leaderFile))
+ } catch {
+ // A NoNodeException may be thrown if the old leader died since the start of this method call.
+ // This is fine -- just check again, since we're guaranteed to see the new values.
+ case e: KeeperException.NoNodeException =>
+ logInfo("Leader disappeared while reading it -- finding next leader")
+ checkLeader()
+ return
+ }
+
+ // Synchronization used to ensure no interleaving between the creation of a new session and the
+ // checking of a leader, which could cause us to delete our real leader file erroneously.
+ synchronized {
+ val isLeader = myLeaderFile == leaderFile
+ if (!isLeader && leaderUrl == masterUrl) {
+ // We found a different master file pointing to this process.
+ // This can happen in the following two cases:
+ // (1) The master process was restarted on the same node.
+ // (2) The ZK server died between creating the node and returning the name of the node.
+ // For this case, we will end up creating a second file, and MUST explicitly delete the
+ // first one, since our ZK session is still open.
+ // Note that this deletion will cause a NodeDeleted event to be fired so we check again for
+ // leader changes.
+ assert(leaderFile < myLeaderFile)
+ logWarning("Cleaning up old ZK master election file that points to this master.")
+ zk.delete(leaderFile)
+ } else {
+ updateLeadershipStatus(isLeader)
+ }
+ }
+ }
+
+ def updateLeadershipStatus(isLeader: Boolean) {
+ if (isLeader && status == LeadershipStatus.NOT_LEADER) {
+ status = LeadershipStatus.LEADER
+ masterActor ! ElectedLeader
+ } else if (!isLeader && status == LeadershipStatus.LEADER) {
+ status = LeadershipStatus.NOT_LEADER
+ masterActor ! RevokedLeadership
+ }
+ }
+
+ private object LeadershipStatus extends Enumeration {
+ type LeadershipStatus = Value
+ val LEADER, NOT_LEADER = Value
+ }
+}
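
A toy illustration (not from the patch) of why masters.sorted.head in checkLeader() identifies the leader: EPHEMERAL_SEQUENTIAL znodes share the "master_" prefix, so lexicographic order matches creation order and the oldest surviving node wins.

    object LeaderPickSketch {
      def main(args: Array[String]) {
        val masters = List("master_0000000002", "master_0000000000", "master_0000000001")
        println(masters.sorted.head)   // prints master_0000000000, the earliest-created node
      }
    }
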
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
new file mode 100644
index 0000000000..825344b3bb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.Logging
+import org.apache.zookeeper._
+
+import akka.serialization.Serialization
+
+class ZooKeeperPersistenceEngine(serialization: Serialization)
+ extends PersistenceEngine
+ with SparkZooKeeperWatcher
+ with Logging
+{
+ val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
+
+ val zk = new SparkZooKeeperSession(this)
+
+ zk.connect()
+
+ override def zkSessionCreated() {
+ zk.mkdirRecursive(WORKING_DIR)
+ }
+
+ override def zkDown() {
+ logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
+ }
+
+ override def addApplication(app: ApplicationInfo) {
+ serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
+ }
+
+ override def removeApplication(app: ApplicationInfo) {
+ zk.delete(WORKING_DIR + "/app_" + app.id)
+ }
+
+ override def addWorker(worker: WorkerInfo) {
+ serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
+ }
+
+ override def removeWorker(worker: WorkerInfo) {
+ zk.delete(WORKING_DIR + "/worker_" + worker.id)
+ }
+
+ override def close() {
+ zk.close()
+ }
+
+ override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
+ val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
+ val appFiles = sortedFiles.filter(_.startsWith("app_"))
+ val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
+ val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
+ val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
+ (apps, workers)
+ }
+
+ private def serializeIntoFile(path: String, value: AnyRef) {
+ val serializer = serialization.findSerializerFor(value)
+ val serialized = serializer.toBinary(value)
+ zk.create(path, serialized, CreateMode.PERSISTENT)
+ }
+
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
+ val fileData = zk.getData(WORKING_DIR + "/" + filename)
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
+ val serializer = serialization.serializerFor(clazz)
+ serializer.fromBinary(fileData).asInstanceOf[T]
+ }
+}
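
A hedged wiring sketch for the engine above: obtain Akka's Serialization extension from the ActorSystem and hand it to the ZooKeeper-backed engine, falling back to the no-op engine otherwise. The "spark.deploy.recoveryMode" property name is an assumption here, not something defined in this file:

    package org.apache.spark.deploy.master

    import akka.actor.ActorSystem
    import akka.serialization.SerializationExtension

    private[spark] object PersistenceWiringSketch {
      def engineFor(system: ActorSystem): PersistenceEngine =
        System.getProperty("spark.deploy.recoveryMode", "NONE") match {
          case "ZOOKEEPER" => new ZooKeeperPersistenceEngine(SerializationExtension(system))
          case _           => new BlackHolePersistenceEngine()
        }
    }
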
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index e3dc30eefc..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -43,7 +43,8 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
- val workDir: File)
+ val workDir: File,
+ var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
@@ -83,7 +84,8 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
+ state = ExecutorState.KILLED
+ worker ! ExecutorStateChanged(appId, execId, state, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@@ -102,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
@@ -180,9 +182,9 @@ private[spark] class ExecutorRunner(
// long-lived processes only. However, in the future, we might restart the executor a few
// times on the same machine.
val exitCode = process.waitFor()
+ state = ExecutorState.FAILED
val message = "Command exited with code " + exitCode
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
- Some(exitCode))
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
@@ -192,8 +194,9 @@ private[spark] class ExecutorRunner(
if (process != null) {
process.destroy()
}
+ state = ExecutorState.FAILED
val message = e.getClass + ": " + e.getMessage
- worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
+ worker ! ExecutorStateChanged(appId, execId, state, Some(message), None)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 3904b701b2..991b22d9f8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -23,26 +23,42 @@ import java.io.File
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor._
import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent}
import org.apache.spark.Logging
-import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
-
-
+import org.apache.spark.deploy.DeployMessages.WorkerStateResponse
+import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
+import org.apache.spark.deploy.DeployMessages.KillExecutor
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import scala.Some
+import akka.remote.DisassociatedEvent
+import org.apache.spark.deploy.DeployMessages.LaunchExecutor
+import org.apache.spark.deploy.DeployMessages.RegisterWorker
+import org.apache.spark.deploy.DeployMessages.WorkerSchedulerStateResponse
+import org.apache.spark.deploy.DeployMessages.MasterChanged
+import org.apache.spark.deploy.DeployMessages.Heartbeat
+import org.apache.spark.deploy.DeployMessages.RegisteredWorker
+import akka.actor.Terminated
+
+/**
+ * @param masterUrls Each url should look like spark://host:port.
+ */
private[spark] class Worker(
host: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
- masterUrl: String,
+ masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
@@ -54,8 +70,18 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+ val REGISTRATION_TIMEOUT = 20.seconds
+ val REGISTRATION_RETRIES = 3
+
+ // Index into masterUrls that we're currently trying to register with.
+ var masterIndex = 0
+
+ val masterLock: Object = new Object()
var master: ActorRef = null
- var masterWebUiUrl : String = ""
+ var activeMasterUrl: String = ""
+ var activeMasterWebUiUrl : String = ""
+ @volatile var registered = false
+ @volatile var connected = false
val workerId = generateWorkerId()
var sparkHome: File = null
var workDir: File = null
@@ -95,6 +121,7 @@ private[spark] class Worker(
}
override def preStart() {
+ assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
@@ -103,46 +130,100 @@ private[spark] class Worker(
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
webUi.start()
- connectToMaster()
+ registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
}
- def connectToMaster() {
- logInfo("Connecting to master " + masterUrl)
- master = context.actorFor(Master.toAkkaUrl(masterUrl))
- master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
- context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ def changeMaster(url: String, uiUrl: String) {
+ masterLock.synchronized {
+ activeMasterUrl = url
+ activeMasterWebUiUrl = uiUrl
+ master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ connected = true
+ }
+ }
+
+ def tryRegisterAllMasters() {
+ for (masterUrl <- masterUrls) {
+ logInfo("Connecting to master " + masterUrl + "...")
+ val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
+ publicAddress)
+ }
+ }
+
+ def registerWithMaster() {
+ tryRegisterAllMasters()
+
+ var retries = 0
+ lazy val retryTimer: Cancellable =
+ context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
+ retries += 1
+ if (registered) {
+ retryTimer.cancel()
+ } else if (retries >= REGISTRATION_RETRIES) {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
+ } else {
+ tryRegisterAllMasters()
+ }
+ }
+ retryTimer // start timer
}
import context.dispatcher
override def receive = {
- case RegisteredWorker(url) =>
- masterWebUiUrl = url
- logInfo("Successfully registered with master")
- context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
- master ! Heartbeat(workerId)
+ case RegisteredWorker(masterUrl, masterWebUiUrl) =>
+ logInfo("Successfully registered with master " + masterUrl)
+ registered = true
+ changeMaster(masterUrl, masterWebUiUrl)
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
+
+ case SendHeartbeat =>
+ masterLock.synchronized {
+ if (connected) { master ! Heartbeat(workerId) }
}
+ case MasterChanged(masterUrl, masterWebUiUrl) =>
+ logInfo("Master has changed, new master is at " + masterUrl)
+ context.unwatch(master)
+ changeMaster(masterUrl, masterWebUiUrl)
+
+ val execs = executors.values.
+ map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
+ sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
+
case RegisterWorkerFailed(message) =>
- logError("Worker registration failed: " + message)
- System.exit(1)
-
- case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
- logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
- val manager = new ExecutorRunner(
- appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
- executors(appId + "/" + execId) = manager
- manager.start()
- coresUsed += cores_
- memoryUsed += memory_
- master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
+ if (!registered) {
+ logError("Worker registration failed: " + message)
+ System.exit(1)
+ }
+
+ case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
+ } else {
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
+ self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
+ executors(appId + "/" + execId) = manager
+ manager.start()
+ coresUsed += cores_
+ memoryUsed += memory_
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
+ }
+ }
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
- master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ masterLock.synchronized {
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ }
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
@@ -155,14 +236,18 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
- case KillExecutor(appId, execId) =>
- val fullId = appId + "/" + execId
- executors.get(fullId) match {
- case Some(executor) =>
- logInfo("Asked to kill executor " + fullId)
- executor.kill()
- case None =>
- logInfo("Asked to kill unknown executor " + fullId)
+ case KillExecutor(masterUrl, appId, execId) =>
+ if (masterUrl != activeMasterUrl) {
+ logWarning("Invalid Master (" + masterUrl + ") attempted to kill executor " + execId)
+ } else {
+ val fullId = appId + "/" + execId
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
}
case DisassociatedEvent(_, _, _) =>
@@ -170,17 +255,14 @@ private[spark] class Worker(
case RequestWorkerState => {
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
- finishedExecutors.values.toList, masterUrl, cores, memory,
- coresUsed, memoryUsed, masterWebUiUrl)
+ finishedExecutors.values.toList, activeMasterUrl, cores, memory,
+ coresUsed, memoryUsed, activeMasterWebUiUrl)
}
}
def masterDisconnected() {
- // TODO: It would be nice to try to reconnect to the master, but just shut down for now.
- // (Note that if reconnecting we would also need to assign IDs differently.)
- logError("Connection to master failed! Shutting down.")
- executors.values.foreach(_.kill())
- System.exit(1)
+ logError("Connection to master failed! Waiting for master to reconnect...")
+ connected = false
}
def generateWorkerId(): String = {
@@ -198,17 +280,18 @@ private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
- args.memory, args.master, args.workDir)
+ args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
- masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
+ : (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrl, workDir), name = "Worker")
+ masterUrls, workDir), name = "Worker")
(actorSystem, boundPort)
}
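
The registerWithMaster() change above leans on a self-cancelling retry timer: a lazy Cancellable can refer to itself from inside the scheduled block, so the first tick that observes success cancels the timer. A generic sketch of that idiom (tryRegister is a hypothetical stand-in, not real API):

    import scala.concurrent.duration._
    import akka.actor.{Actor, Cancellable}

    trait RetryingRegistration { this: Actor =>
      import context.dispatcher
      @volatile var registered = false

      def tryRegister(): Unit   // sends the registration messages; supplied by the actor

      def startRetries(maxRetries: Int, interval: FiniteDuration = 20.seconds) {
        tryRegister()
        var retries = 0
        lazy val timer: Cancellable =
          context.system.scheduler.schedule(interval, interval) {
            retries += 1
            if (registered) timer.cancel()
            else if (retries >= maxRetries) context.stop(self)
            else tryRegister()
          }
        timer   // force the lazy val so the schedule actually starts
      }
    }
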
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index 0ae89a864f..3ed528e6b3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
- var master: String = null
+ var masters: Array[String] = null
var workDir: String = null
// Check for settings in environment variables
@@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) {
printUsageAndExit(0)
case value :: tail =>
- if (master != null) { // Two positional arguments were given
+ if (masters != null) { // Two positional arguments were given
printUsageAndExit(1)
}
- master = value
+ masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
parse(tail)
case Nil =>
- if (master == null) { // No positional argument was given
+ if (masters == null) { // No positional argument was given
printUsageAndExit(1)
}
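
A quick illustration of the parsing above: a single positional argument can now carry several comma-separated masters under one spark:// prefix.

    object MasterListParseSketch {
      def parse(value: String): Array[String] =
        value.stripPrefix("spark://").split(",").map("spark://" + _)

      def main(args: Array[String]) {
        // prints: spark://host1:7077, spark://host2:7077
        println(parse("spark://host1:7077,host2:7077").mkString(", "))
      }
    }
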
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 07bc479c83..a38e32b339 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -108,7 +108,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
- val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
+ val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index f705a5631a..73fa7d6b6a 100644
--- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,23 +24,15 @@ import akka.remote._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.DisassociatedEvent
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.AssociationErrorEvent
import akka.remote.DisassociatedEvent
import akka.actor.Terminated
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
-private[spark] class StandaloneExecutorBackend(
+private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
@@ -75,15 +67,28 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
if (executor == null) {
- logError("Received launchTask but executor was null")
+ logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
+ case KillTask(taskId, _) =>
+ if (executor == null) {
+ logError("Received KillTask command but executor was null")
+ System.exit(1)
+ } else {
+ executor.killTask(taskId)
+ }
+
case DisassociatedEvent(_, _, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
+
+ case StopExecutor =>
+ logInfo("Driver commanded a shutdown")
+ context.stop(self)
+ context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@@ -91,7 +96,7 @@ private[spark] class StandaloneExecutorBackend(
}
}
-private[spark] object StandaloneExecutorBackend {
+private[spark] object CoarseGrainedExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
@@ -102,16 +107,19 @@ private[spark] object StandaloneExecutorBackend {
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
+
actorSystem.actorOf(
- Props(classOf[StandaloneExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
+ // the reason we allow the last appid argument is to make it easy to kill rogue executors
+ System.err.println(
+ "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
+ "[<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3800063234..de4540493a 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,9 +25,10 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
/**
@@ -36,7 +37,8 @@ import org.apache.spark.util.Utils
private[spark] class Executor(
executorId: String,
slaveHostname: String,
- properties: Seq[(String, String)])
+ properties: Seq[(String, String)],
+ isLocal: Boolean = false)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -73,46 +75,75 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
- // Make any thread terminations due to uncaught exceptions kill the entire
- // executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(
- new Thread.UncaughtExceptionHandler {
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ if (!isLocal) {
+ // Setup an uncaught exception handler for non-local mode.
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
- }
- )
+ )
+ }
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
- SparkEnv.set(env)
- env.metricsSystem.registerSource(executorSource)
+ private val env = {
+ if (!isLocal) {
+ val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+ isDriver = false, isLocal = false)
+ SparkEnv.set(_env)
+ _env.metricsSystem.registerSource(executorSource)
+ _env
+ } else {
+ SparkEnv.get
+ }
+ }
private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
// Start worker thread pool
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+
+ // Maintains the list of running tasks.
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
- threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+ val tr = new TaskRunner(context, taskId, serializedTask)
+ runningTasks.put(taskId, tr)
+ threadPool.execute(tr)
+ }
+
+ def killTask(taskId: Long) {
+ val tr = runningTasks.get(taskId)
+ if (tr != null) {
+ tr.kill()
+ // The task is also removed in the finally block of TaskRunner.run.
+ // We remove it here as well because killTask might be called before the task is even
+ // launched, in which case it would never reach that finally block. ConcurrentHashMap's
+ // remove is idempotent.
+ runningTasks.remove(taskId)
+ }
}
/** Get the Yarn approved local directories. */
@@ -124,56 +155,87 @@ private[spark] class Executor(
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
.getOrElse(""))
- if (localDirs.isEmpty()) {
+ if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
}
- return localDirs
+ localDirs
}
- class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
+ class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
- override def run() {
+ @volatile private var killed = false
+ @volatile private var task: Task[Any] = _
+
+ def kill() {
+ logInfo("Executor is trying to kill task " + taskId)
+ killed = true
+ if (task != null) {
+ task.kill()
+ }
+ }
+
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
- context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
+ execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
+ def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ val startGCTime = gcTime
try {
SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
- val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+ // continue executing the task.
+ if (killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
attemptedTask = Some(task)
- logInfo("Its epoch is " + task.epoch)
+ logDebug("Task " + taskId +"'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
+
+ // Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
+
+ // If the task has been killed, let's fail it.
+ if (task.killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
for (m <- task.metrics) {
- m.hostname = Utils.localHostName
+ m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
- //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
- // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
- // just change the relevants bytes in the byte buffer
+ // TODO I'd also like to track the time it takes to serialize the task results, but that is a
+ // huge headache because we need to serialize the task metrics first. If TaskMetrics had a
+ // custom serialized format, we could just change the relevant bytes in the byte buffer.
val accumUpdates = Accumulators.values
+
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
logInfo("Storing result for " + taskId + " in local BlockManager")
- val blockId = "taskresult_" + taskId
+ val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
ser.serialize(new IndirectTaskResult[Any](blockId))
@@ -182,12 +244,13 @@ private[spark] class Executor(
serializedDirectResult
}
}
- context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
+
+ execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
case t: Throwable => {
@@ -195,10 +258,10 @@ private[spark] class Executor(
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
@@ -206,6 +269,8 @@ private[spark] class Executor(
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
+ } finally {
+ runningTasks.remove(taskId)
}
}
}
@@ -215,7 +280,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
- var loader = this.getClass.getClassLoader
+ val loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -237,7 +302,7 @@ private[spark] class Executor(
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- return constructor.newInstance(classUri, parent)
+ constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -245,7 +310,7 @@ private[spark] class Executor(
null
}
} else {
- return parent
+ parent
}
}
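
A generic sketch (not the patch's TaskRunner) of the cooperative-kill pattern introduced above: kill() flips a volatile flag, the runner checks it at safe points, and long-running work polls it, so a kill never has to interrupt the thread.

    class CancellableWork extends Runnable {
      @volatile private var killed = false

      def kill() { killed = true }

      override def run() {
        if (killed) return                    // killed before the work even started
        var i = 0
        while (i < 10000000 && !killed) {     // the "work" polls the flag as it runs
          i += 1
        }
        if (killed) println("work was killed") else println("work finished at " + i)
      }
    }
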
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091980..b56d8c9912 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
+
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
- logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ if (executor == null) {
+ logError("Received KillTask but executor was null")
+ } else {
+ executor.killTask(t.getValue.toLong)
+ }
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index f311141148..0b4892f98f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
+
+ /**
+ * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ */
+ var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index c24fd48c04..703bc6a9ca 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
- implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index 3c29700920..1b9fa1e53a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
import io.netty.buffer._
import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
private[spark] class FileHeader (
val fileLen: Int,
- val blockId: String) extends Logging {
+ val blockId: BlockId) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
- buf.writeInt(blockId.length)
- blockId.foreach((x: Char) => buf.writeByte(x))
+ buf.writeInt(blockId.name.length)
+ blockId.name.foreach((x: Char) => buf.writeByte(x))
//padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
- val blockId = idBuilder.toString()
+ val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}
-
- def main (args:Array[String]){
-
- val header = new FileHeader(25,"block_0");
- val buf = header.buffer;
- val newheader = FileHeader.create(buf);
- System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+ def main (args:Array[String]) {
+ val header = new FileHeader(25, TestBlockId("my_block"))
+ val buf = header.buffer
+ val newHeader = FileHeader.create(buf)
+ System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 9493ccffd9..481ff8c3e0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -27,12 +27,13 @@ import org.apache.spark.Logging
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
private[spark] class ShuffleCopier extends Logging {
- def getBlock(host: String, port: Int, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(host: String, port: Int, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
try {
fc.init()
fc.connect(host, port)
- fc.sendRequest(blockId)
+ fc.sendRequest(blockId.name)
fc.waitForClose()
fc.close()
} catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
}
}
- def getBlock(cmId: ConnectionManagerId, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
- blocks: Seq[(String, Long)],
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ blocks: Seq[(BlockId, Long)],
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
@@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
private[spark] object ShuffleCopier extends Logging {
- private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- override def handleError(blockId: String) {
+ override def handleError(blockId: BlockId) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
- val file = args(2)
+ val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
@@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
- copier.getBlock(host, port, file, echoResultCollectCallBack)
+ copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
- copiers.shutdown
+ copiers.shutdown()
System.exit(0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 8afcbe190a..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,6 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -53,8 +54,8 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory it hashes to, and which subdirectory in that
@@ -62,8 +63,8 @@ private[spark] object ShuffleSender {
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
- val file = new File(subDir, blockId)
- return file.getAbsolutePath
+ val file = new File(subDir, blockId.name)
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
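A standalone sketch of the directory-hashing scheme the resolver above relies on: a non-negative hash of the block name picks a configured local dir, and the quotient picks one of subDirsPerLocalDir "%02x" sub-directories. The hash function here is an assumption; the real code computes `hash` from the block id a few lines above the shown hunk.

    import java.io.File

    def hashedBlockFile(localDirs: Array[File], subDirsPerLocalDir: Int, blockName: String): File = {
      val hash = blockName.hashCode & Integer.MAX_VALUE                // assumed non-negative hash
      val dirId = hash % localDirs.length                              // which local dir
      val subDirId = (hash / localDirs.length) % subDirsPerLocalDir    // which sub-dir
      val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
      new File(subDir, blockName)
    }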
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index f132e2b735..70a5a8caff 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+package org.apache
+
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
new file mode 100644
index 0000000000..44c5078621
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import scala.reflect.ClassTag
+
+/**
+ * A set of asynchronous RDD actions available through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
+
+ /**
+ * Returns a future for counting the number of elements in the RDD.
+ */
+ def countAsync(): FutureAction[Long] = {
+ val totalCount = new AtomicLong
+ self.context.submitJob(
+ self,
+ (iter: Iterator[T]) => {
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ },
+ Range(0, self.partitions.size),
+ (index: Int, data: Long) => totalCount.addAndGet(data),
+ totalCount.get())
+ }
+
+ /**
+ * Returns a future for retrieving all elements of this RDD.
+ */
+ def collectAsync(): FutureAction[Seq[T]] = {
+ val results = new Array[Array[T]](self.partitions.size)
+ self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
+ (index, data) => results(index) = data, results.flatten.toSeq)
+ }
+
+ /**
+ * Returns a future for retrieving the first num elements of the RDD.
+ */
+ def takeAsync(num: Int): FutureAction[Seq[T]] = {
+ val f = new ComplexFutureAction[Seq[T]]
+
+ f.run {
+ val results = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+ var partsScanned = 0
+ while (results.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (results.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = num - results.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+ val buf = new Array[Array[T]](p.size)
+ f.runJob(self,
+ (it: Iterator[T]) => it.take(left).toArray,
+ p,
+ (index: Int, data: Array[T]) => buf(index) = data,
+ Unit)
+
+ buf.foreach(results ++= _.take(num - results.size))
+ partsScanned += numPartsToTry
+ }
+ results.toSeq
+ }
+
+ f
+ }
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+
+ /**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+}
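A driver-side usage sketch for the new async actions, assuming a local SparkContext named sc and that the returned FutureAction supports the standard scala.concurrent.Future callback interface (the class comment above only guarantees that importing SparkContext._ brings these methods into scope):

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._            // AsyncRDDActions implicit conversion
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.util.{Failure, Success}

    val sc = new SparkContext("local", "async-demo")  // assumed local test context
    val rdd = sc.parallelize(1 to 1000, 4)

    // Kick off the count without blocking the driver thread.
    rdd.countAsync().onComplete {
      case Success(n) => println("counted " + n + " elements")
      case Failure(e) => println("count failed: " + e)
    }

    // takeAsync scans partitions incrementally until enough rows have been collected.
    val firstTen = rdd.takeAsync(10)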
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index fe2946bcbe..63b9fe1478 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -18,15 +18,15 @@
package org.apache.spark.rdd
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockId, BlockManager}
import scala.reflect.ClassTag
-private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
+class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3f4d4ad46a..99ea6e8ee8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
@@ -84,9 +85,9 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
- val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
@@ -123,7 +124,7 @@ private[spark] object CheckpointRDD extends Logging {
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
@@ -146,7 +147,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 0187256a8e..911a002884 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -18,13 +18,12 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
-import java.util.{HashMap => JHashMap}
-import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
- val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+ val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
- def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
- val seq = map.get(k)
- if (seq != null) {
- seq
- } else {
- val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
- map.put(k, seq)
- seq
- }
+ val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+ if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
+ }
+
+ val getSeq = (k: K) => {
+ map.changeValue(k, update)
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@@ -129,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
- JavaConversions.mapAsScalaMap(map).iterator
+ new InterruptibleIterator(context, map.iterator)
}
override def clearDependencies() {
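The update closure above targets AppendOnlyMap.changeValue, which, as used here, hands the closure a flag for whether the key already had a value plus the old value, stores what the closure returns, and returns it. Below is a hypothetical stand-in with the same shape backed by a plain mutable.HashMap, purely to illustrate the contract this hunk assumes; with it, the cogroup code allocates the Array of ArrayBuffers only the first time a key is seen.

    import scala.collection.mutable

    class ChangeValueMapSketch[K, V] {
      private val underlying = mutable.HashMap.empty[K, V]

      // changeValue(k, (hadVal, oldVal) => newVal): create-or-update in one call.
      def changeValue(k: K, update: (Boolean, V) => V): V = {
        val newVal = underlying.get(k) match {
          case Some(old) => update(true, old)                    // key present: reuse or replace
          case None      => update(false, null.asInstanceOf[V])  // key absent: closure builds a value
        }
        underlying(k) = newVal
        newVal
      }

      def iterator: Iterator[(K, V)] = underlying.iterator
    }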
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index d3b3fffd40..32901a508f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -27,54 +27,19 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
- TaskContext}
+import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
-/**
- * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
- * system, or S3).
- * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
- * across multiple reads; the 'path' is the only variable that is different across new JobConfs
- * created from the Configuration.
- */
-class HadoopFileRDD[K, V](
- sc: SparkContext,
- path: String,
- broadcastedConf: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int)
- extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) {
-
- override def getJobConf(): JobConf = {
- if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- // getJobConf() has been called previously, so there is already a local cache of the JobConf
- // needed by this RDD.
- return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
- } else {
- // Create a new JobConf, set the input file/directory paths to read from, and cache the
- // JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through
- // HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple
- // getJobConf() calls for this RDD in the local process.
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- val newJobConf = new JobConf(broadcastedConf.value.value)
- FileInputFormat.setInputPaths(newJobConf, path)
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- return newJobConf
- }
- }
-}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
extends Partition {
-
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -83,11 +48,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
}
/**
- * A base class that provides core functionality for reading data partitions stored in Hadoop.
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job.
+ * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
+ * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
+ * creates.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
*/
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
@@ -105,6 +83,7 @@ class HadoopRDD[K, V](
sc,
sc.broadcast(new SerializableWritable(conf))
.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ None /* initLocalJobConfFuncOpt */,
inputFormatClass,
keyClass,
valueClass,
@@ -130,6 +109,7 @@ class HadoopRDD[K, V](
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
val newJobConf = new JobConf(broadcastedConf.value.value)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
return newJobConf
}
@@ -164,38 +144,41 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
- logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
-
- val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
-
- val key: K = reader.createKey()
- val value: V = reader.createValue()
-
- override def getNext() = {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new NextIterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
+ logInfo("Input split: " + split.inputSplit)
+ var reader: RecordReader[K, V] = null
+
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
+ override def getNext() = {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ (key, value)
}
- (key, value)
- }
- override def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ override def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator[(K, V)](context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
@@ -216,10 +199,10 @@ private[spark] object HadoopRDD {
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
*/
- def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
+ def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
- def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
+ def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
def putCachedMetadata(key: String, value: Any) =
- SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
+ SparkEnv.get.hadoopJobMetadata.put(key, value)
}
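The removed HadoopFileRDD used to bake the input path into each locally created JobConf; with the new constructor the same effect is obtained by passing a small initializer closure as initLocalJobConfFuncOpt, which getJobConf() applies to every JobConf it builds before caching it. A sketch of such a closure (the path is a placeholder):

    import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

    val path = "hdfs://namenode:8020/data/input"          // placeholder input path
    val initLocalJobConfFunc: JobConf => Unit =
      (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)

    // Supplied to the HadoopRDD constructor as Some(initLocalJobConfFunc).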
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
index 3cf22851dd..67636751bb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -22,14 +22,14 @@ import scala.reflect.ClassTag
/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interruptible flag or get the index
+ * of the partition in the RDD.
*/
private[spark]
-class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
+class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U],
+ f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@@ -38,5 +38,5 @@ class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
- f(split.index, firstParent[T].iterator(split, context))
+ f(context, firstParent[T].iterator(split, context))
}
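A usage sketch for the renamed RDD, exposed as mapPartitionsWithContext later in this patch (in RDD.scala), assuming a SparkContext named sc; the closure reads the partition index off the TaskContext rather than receiving it directly:

    val data = sc.parallelize(1 to 100, 4)

    // Tag each element with the partition it came from, via the TaskContext.
    val tagged = data.mapPartitionsWithContext(
      (context, iter) => iter.map(x => (context.partitionId, x)))

    tagged.take(5).foreach(println)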
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b3a89f7e0..2662d48c84 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- if (format.isInstanceOf[Configurable]) {
- format.asInstanceOf[Configurable].setConf(conf)
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
-
- var havePair = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- havePair = !finished
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
+ logInfo("Input split: " + split.serializableHadoopSplit)
+ val conf = confBroadcast.value.value
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => close())
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ }
+ !finished
}
- !finished
- }
- override def next: (K, V) = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ (reader.getCurrentKey, reader.getCurrentValue)
}
- havePair = false
- return (reader.getCurrentKey, reader.getCurrentValue)
- }
- private def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c8e623081a..0c2a051a42 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -85,18 +85,24 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+ partitioned.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ }, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ values.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
}
}
@@ -565,7 +571,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -664,7 +670,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0
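For reference, a small driver-side use of combineByKey itself, the public entry point whose internals the hunk above rewires onto mapPartitionsWithContext; sc is an assumed SparkContext, and the combiners build a (sum, count) pair per key:

    val scores = sc.parallelize(Seq(("alice", 80), ("bob", 70), ("alice", 90)))

    val sumAndCount = scores.combineByKey(
      (v: Int) => (v, 1),                                            // createCombiner
      (acc: (Int, Int), v: Int) => (acc._1 + v, acc._2 + 1),         // mergeValue
      (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2))  // mergeCombiners

    val averages = sumAndCount.mapValues { case (sum, count) => sum.toDouble / count }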
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 78fe0cdcdb..09d0a8189d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -96,8 +96,9 @@ private[spark] class ParallelCollectionRDD[T: ClassTag](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def compute(s: Partition, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+ override def compute(s: Partition, context: TaskContext) = {
+ new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+ }
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 731ef90c90..3c237ca20a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -269,6 +269,19 @@ abstract class RDD[T: ClassTag](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): RDD[T] = {
+ coalesce(numPartitions, true)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
@@ -421,26 +434,39 @@ abstract class RDD[T: ClassTag](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
+ preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassTag](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
+ new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+ * mapPartitions that also passes the TaskContext into the closure.
+ */
+ def mapPartitionsWithContext[U: ClassTag](
+ f: (TaskContext, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = {
+ new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@@ -448,22 +474,23 @@ abstract class RDD[T: ClassTag](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassTag](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ mapPartitionsWithIndex(f, preservesPartitioning)
+ }
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => U): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.map(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def mapWith[A: ClassTag, U: ClassTag]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => U): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -471,13 +498,14 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => Seq[U]): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def flatMapWith[A: ClassTag, U: ClassTag]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => Seq[U]): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -485,13 +513,12 @@ abstract class RDD[T: ClassTag](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassTag](constructA: Int => A)
- (f:(T, A) => Unit) {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
- }
- (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => {f(t, a); t})
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@@ -499,13 +526,12 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassTag](constructA: Int => A)
- (p:(T, A) => Boolean): RDD[T] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.filter(t => p(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@@ -544,16 +570,14 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@@ -678,6 +702,8 @@ abstract class RDD[T: ClassTag](
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
+ // Use a while loop to count the number of elements rather than iter.size because
+ // iter.size uses a for loop, which is slightly slower in current version of Scala.
var result = 0L
while (iter.hasNext) {
result += 1L
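A quick usage sketch of the new repartition alongside coalesce, assuming a SparkContext named sc (the input path is a placeholder): repartition always shuffles, so it can grow the partition count, while coalesce without a shuffle is the cheaper way to shrink it.

    val logs = sc.textFile("hdfs://namenode:8020/logs", 2)   // placeholder input

    val wide = logs.repartition(16)    // grow parallelism; performs a shuffle
    val narrow = wide.coalesce(4)      // shrink without a shuffle (narrow dependency)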
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index b7205865cf..1d109a2496 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -57,7 +57,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 85c512f3de..aab30b1bb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -111,7 +111,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
case ShuffleCoGroupSplitDep(shuffleId) => {
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context.taskMetrics, serializer)
+ context, serializer)
iter.foreach(op)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 03fe0e00f9..ab7b3a2e24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -29,8 +29,8 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -42,34 +42,40 @@ import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
* locations to run each task on, based on the current cache status, and passes these to the
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
- * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
+ * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
- * events are submitted using a synchonized queue (eventQueue). The public API methods, such as
+ * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
* should be private.
*/
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
- extends TaskSchedulerListener with Logging {
+ extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
- taskSched.setListener(this)
+ taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}
+ // Called to report that a task has completed and results are being fetched remotely.
+ def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+ eventQueue.put(GettingResultEvent(task, taskInfo))
+ }
+
// Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
+ def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -80,17 +86,18 @@ class DAGScheduler(
}
// Called by TaskScheduler when an executor fails.
- override def executorLost(execId: String) {
+ def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, host: String) {
+ def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}
- // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
+ // cancellation of the job itself.
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -105,13 +112,15 @@ class DAGScheduler(
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
- val nextJobId = new AtomicInteger(0)
+ private[scheduler] val nextJobId = new AtomicInteger(0)
+
+ def numTotalJobs: Int = nextJobId.get()
- val nextStageId = new AtomicInteger(0)
+ private val nextStageId = new AtomicInteger(0)
- val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
- val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -128,6 +137,7 @@ class DAGScheduler(
// stray messages to detect.
val failedEpoch = new HashMap[String, Long]
+ // stage id to the active job
val idToActiveJob = new HashMap[Int, ActiveJob]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -139,7 +149,7 @@ class DAGScheduler(
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
- val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
// Start a thread to run the DAGScheduler event loop
def start() {
@@ -157,7 +167,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@@ -179,7 +189,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -192,6 +202,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
+ numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@@ -204,9 +215,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ val stage =
+ new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
- stageToInfos(stage) = StageInfo(stage)
+ stageToInfos(stage) = new StageInfo(stage)
stage
}
@@ -262,32 +274,41 @@ class DAGScheduler(
}
/**
- * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
- * JobWaiter whose getResult() method will return the result of the job when it is complete.
- *
- * The job is assumed to have at least one partition; zero partition jobs should be handled
- * without a JobSubmitted event.
+ * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+ * can be used to block until the job finishes executing or can be used to cancel the job.
*/
- private[scheduler] def prepareJob[T, U: ClassTag](
- finalRdd: RDD[T],
+ def submitJob[T, U](
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
- properties: Properties = null)
- : (JobSubmitted, JobWaiter[U]) =
+ properties: Properties = null): JobWaiter[U] =
{
+ // Check to make sure we are not launching a task on a partition that does not exist.
+ val maxPartitions = rdd.partitions.length
+ partitions.find(p => p >= maxPartitions).foreach { p =>
+ throw new IllegalArgumentException(
+ "Attempting to access a non-existent partition: " + p + ". " +
+ "Total number of partitions: " + maxPartitions)
+ }
+
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
- val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
- properties)
- (toSubmit, waiter)
+ val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+ waiter, properties))
+ waiter
}
def runJob[T, U: ClassTag](
- finalRdd: RDD[T],
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
@@ -295,21 +316,7 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
- if (partitions.size == 0) {
- return
- }
-
- // Check to make sure we are not launching a task on a partition that does not exist.
- val maxPartitions = finalRdd.partitions.length
- partitions.find(p => p >= maxPartitions).foreach { p =>
- throw new IllegalArgumentException(
- "Attempting to access a non-existent partition: " + p + ". " +
- "Total number of partitions: " + maxPartitions)
- }
-
- val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
- eventQueue.put(toSubmit)
+ val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@@ -330,19 +337,40 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+ val jobId = nextJobId.getAndIncrement()
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+ listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
/**
+ * Cancel a job that is running or waiting in the queue.
+ */
+ def cancelJob(jobId: Int) {
+ logInfo("Asked to cancel job " + jobId)
+ eventQueue.put(JobCancelled(jobId))
+ }
+
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventQueue.put(JobGroupCancelled(groupId))
+ }
+
+ /**
+ * Cancel all jobs that are running or waiting in the queue.
+ */
+ def cancelAllJobs() {
+ eventQueue.put(AllJobsCancelled)
+ }
+
+ /**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
- val jobId = nextJobId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+ case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+ val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -361,18 +389,43 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case JobCancelled(jobId) =>
+ // Cancel a job: find all the running stages that are linked to this job, and cancel them.
+ running.filter(_.jobId == jobId).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ .map(_.jobId)
+ if (!jobIds.isEmpty) {
+ running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+ }
+
+ case AllJobsCancelled =>
+ // Cancel all running jobs.
+ running.foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
+
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@@ -542,7 +595,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
- listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+ listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -563,9 +616,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- if (!stage.submissionTime.isDefined) {
- stage.submissionTime = Some(System.currentTimeMillis())
- }
+ stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -579,15 +630,20 @@ class DAGScheduler(
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+
+ if (!stageIdToStage.contains(task.stageId)) {
+ // Skip all the actions if the stage has been cancelled.
+ return
+ }
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.submissionTime match {
+ val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
- case _ => "Unkown"
+ case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.completionTime = Some(System.currentTimeMillis)
+ stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@@ -627,7 +683,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
- stage.addOutputLoc(smt.partition, status)
+ stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
@@ -753,14 +809,14 @@ class DAGScheduler(
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
- * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
- failedStage.completionTime = Some(System.currentTimeMillis())
+ stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- val error = new SparkException("Job failed: " + reason)
+ val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
idToActiveJob -= resultStage.jobId
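
To see how these cancellation handlers are reached from a driver program, here is a minimal sketch. It assumes the SparkContext job-group API (setJobGroup/cancelJobGroup) that sets the SPARK_JOB_GROUP_ID property referenced above and posts JobGroupCancelled; the group id, workload, and timings are illustrative only.

    import org.apache.spark.SparkContext

    object CancelGroupSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[4]", "cancel-sketch")
        val worker = new Thread() {
          override def run() {
            // Tag every job submitted from this thread with a group id (assumed API).
            sc.setJobGroup("nightly-etl", "illustrative long-running job")
            try {
              sc.parallelize(1 to 10000000, 100).map(_ * 2).count()
            } catch {
              case e: Exception => println("job was cancelled: " + e.getMessage)
            }
          }
        }
        worker.start()
        Thread.sleep(2000)
        // Posts JobGroupCancelled("nightly-etl"); the handler above cancels the running stages.
        sc.cancelJobGroup("nightly-etl")
        worker.join()
        sc.stop()
      }
    }
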
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 10ff1b4376..708d221d60 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+ jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@@ -43,9 +44,19 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
-private[spark] case class CompletionEvent(
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
+private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -54,10 +65,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
-private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
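
For orientation, these events are produced by enqueueing them for the single scheduling thread mentioned above. A hedged sketch of the cancellation entry points on the DAGScheduler (cancelJob is confirmed by the JobWaiter change further below; the other method names and the eventQueue field are assumptions about the surrounding code):

    // Inside DAGScheduler (sketch): user-facing cancellation only posts an event and returns.
    def cancelJob(jobId: Int) {
      logInfo("Asked to cancel job " + jobId)
      eventQueue.put(JobCancelled(jobId))
    }

    def cancelJobGroup(groupId: String) {
      logInfo("Asked to cancel job group " + groupId)
      eventQueue.put(JobGroupCancelled(groupId))
    }

    def cancelAllJobs() {
      eventQueue.put(AllJobsCancelled)
    }
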
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 151514896f..7b5c0e29ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.nextJobId.get()
+ override def getValue: Int = dagScheduler.numTotalJobs
})
metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 370ccd183c..1791ee660d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation
@@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val conf = new JobConf(configuration)
- env.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
@@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val jobConf = new JobConf(configuration)
- env.hadoop.addCredentials(jobConf)
+ SparkHadoopUtil.get.addCredentials(jobConf)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3628b1b078..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,292 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try{
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log
+ * file per Spark job, containing the RDD graph, task start/stop events, and shuffle information.
+ * JobLogger is a subclass of SparkListener; use SparkContext.addSparkListener to register a
+ * JobLogger after the SparkContext is created.
+ * Note that each JobLogger works for only one SparkContext.
+ * @param user The user under whose log directory (/tmp/spark-<user> by default) files are written.
+ * @param logDirName The name of this run's directory under the base log directory.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+  /** Create a folder for log files; the folder's name is the creation time of the JobLogger. */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+      // JobLogger should throw an exception rather than continue to construct this object.
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+   * @exception FileNotFoundException Failed to create the log file
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close log file, and clean the stage relationship in stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+   * Write info into the log file of the given job
+   * @param jobID ID of the job
+   * @param info Info to be recorded
+   * @param withTime Controls whether to record a time stamp before the info; default is true
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+      writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+   * Write info into the log file of the job that owns the given stage
+   * @param stageID ID of the stage
+   * @param info Info to be recorded
+   * @param withTime Controls whether to record a time stamp before the info; default is true
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build stage dependency for a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record stage dependency and RDD dependency for a stage
+ * @param jobID Job ID of the stage
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate indents and convert to String
+ * @param indent Number of indents
+ * @return string of indents
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record stage dependency graph of a job
+ * @param jobID Job ID of the stage
+ * @param stage Root stage of the job
+ * @param indent Indent number before info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+   * When a stage is submitted, record its submission info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+   * When a stage is completed, record its completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When task ends, record task completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+   * When a job ends, record its completion status and close the log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+    if (properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+   * When a job starts, record its properties and stage graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
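
Because the rewritten JobLogger is still an ordinary SparkListener, registering it stays a one-liner. A minimal sketch (the directory name and workload are arbitrary):

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.JobLogger

    object JobLoggerSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "job-logger-sketch")
        // One log file per job appears under $SPARK_LOG_DIR (or /tmp/spark-<user>) / "demo-run"/.
        sc.addSparkListener(new JobLogger(System.getProperty("user.name", "<unknown>"), "demo-run"))
        sc.parallelize(1 to 1000, 4).map(_ + 1).count()
        sc.stop()
      }
    }
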
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 200d881799..58f238d8cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,48 +17,58 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ArrayBuffer
-
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
-private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+private[spark] class JobWaiter[T](
+ dagScheduler: DAGScheduler,
+ jobId: Int,
+ totalTasks: Int,
+ resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
- private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
- private var jobResult: JobResult = null // If the job is finished, this will be its result
+ // Is the job as a whole finished (succeeded or failed)?
+ private var _jobFinished = totalTasks == 0
- override def taskSucceeded(index: Int, result: Any) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
- }
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
- }
- }
+ def jobFinished = _jobFinished
+
+ // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
+ // partition RDDs), we set the jobResult directly to JobSucceeded.
+ private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+
+ /**
+ * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
+ * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
+ * will fail this job with a SparkException.
+ */
+ def cancel() {
+ dagScheduler.cancelJob(jobId)
}
- override def jobFailed(exception: Exception) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
- }
- jobFinished = true
- jobResult = JobFailed(exception, None)
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ if (_jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ resultHandler(index, result.asInstanceOf[T])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ _jobFinished = true
+ jobResult = JobSucceeded
this.notifyAll()
}
}
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ _jobFinished = true
+ jobResult = JobFailed(exception, None)
+ this.notifyAll()
+ }
+
def awaitResult(): JobResult = synchronized {
- while (!jobFinished) {
+ while (!_jobFinished) {
this.wait()
}
return jobResult
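
The waiter's synchronization is plain monitor-style wait/notifyAll. For readers less familiar with the pattern, a stripped-down, Spark-free sketch of the same idea (all names here are illustrative, not Spark API):

    class ResultWaiter(totalTasks: Int) {
      private var finished = 0
      private var failure: Option[Exception] = None

      def taskSucceeded(): Unit = synchronized {
        finished += 1
        if (finished == totalTasks) notifyAll()   // wake up awaitResult()
      }

      def taskFailed(e: Exception): Unit = synchronized {
        failure = Some(e)
        finished = totalTasks                     // force the wait loop to exit
        notifyAll()
      }

      def awaitResult(): Either[Exception, Unit] = synchronized {
        while (finished < totalTasks) wait()
        failure.toLeft(())
      }
    }
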
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 9eb8d48501..596f9adde9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,10 @@ private[spark] class Pool(
var runningTasks = 0
var priority = 0
- var stageId = 0
+
+  // A pool's stage id is used to break ties in scheduling.
+ var stageId = -1
+
var name = poolName
var parent: Pool = null
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 07e8317e3a..310ec62ca8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
private[spark] object ResultTask {
@@ -32,23 +32,23 @@ private[spark] object ResultTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
- return old
+ old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -56,11 +56,11 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
- return (rdd, func)
+ (rdd, func)
}
def clearCache() {
@@ -71,29 +71,37 @@ private[spark] object ResultTask {
}
+/**
+ * A task that sends back the output to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
+ * input RDD's partitions).
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
- var partition: Int,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
- extends Task[U](stageId) with Externalizable {
+ extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
- var split = if (rdd == null) {
- null
- } else {
- rdd.partitions(partition)
- }
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- override def run(attemptId: Long): U = {
- val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+ override def runTask(context: TaskContext): U = {
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
@@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+ override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partition = in.readInt()
+ partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 4e25086ec9..356fe56bf3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -30,7 +30,10 @@ import scala.xml.XML
 * addTaskSetManager: build the leaf nodes (TaskSetManagers)
*/
private[spark] trait SchedulableBuilder {
+ def rootPool: Pool
+
def buildPools()
+
def addTaskSetManager(manager: Schedulable, properties: Properties)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index d23df0dd2b..1dc71a0428 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage._
-import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner}
+import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
@@ -37,7 +37,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
- val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
- return (rdd, dep)
+ (rdd, dep)
}
}
@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
- return (HashMap(set.toSeq: _*))
+ HashMap(set.toSeq: _*)
}
def clearCache() {
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
}
}
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the partition in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ */
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- var partition: Int,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId)
+ extends Task[MapStatus](stageId, _partitionId)
with Externalizable
with Logging {
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeLong(epoch)
out.writeObject(split)
}
@@ -124,68 +136,70 @@ private[spark] class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
- partition = in.readInt()
+ partitionId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
- override def run(attemptId: Long): MapStatus = {
+ override def runTask(context: TaskContext): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
-
- val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
- metrics = Some(taskContext.taskMetrics)
+ metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlocks = null
- var buckets: ShuffleWriterGroup = null
+ val shuffleBlockManager = blockManager.shuffleBlockManager
+ var shuffle: ShuffleWriterGroup = null
+ var success = false
try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
- shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partition)
+ shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
- for (elem <- rdd.iterator(split, taskContext)) {
+ for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets.writers(bucketId).write(pair)
+ shuffle.writers(bucketId).write(pair)
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
- val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ var totalTime = 0L
+ val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
+ totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ success = true
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (buckets != null) {
- buckets.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null) {
+ shuffle.writers.foreach(_.revertPartialWrites())
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
- if (shuffle != null && buckets != null) {
- shuffle.releaseWriters(buckets)
+ if (shuffle != null && shuffle.writers != null) {
+ shuffle.writers.foreach(_.close())
+ shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
+ context.executeOnCompleteCallbacks()
}
}
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}
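
The heart of runTask above is the routing step: each record goes to the writer indexed by dep.partitioner.getPartition(key). A tiny, self-contained sketch of that bucketing with a HashPartitioner (the sample data is made up):

    import org.apache.spark.HashPartitioner

    object BucketingSketch {
      def main(args: Array[String]) {
        val partitioner = new HashPartitioner(4)
        val records = Seq(("a", 1), ("b", 2), ("c", 3), ("a", 4))
        // Group records by the bucket their key hashes to,
        // mirroring shuffle.writers(bucketId).write(pair) in the task above.
        val byBucket = records.groupBy { case (k, _) => partitioner.getPartition(k) }
        byBucket.foreach { case (bucket, elems) =>
          println("bucket " + bucket + " -> " + elems.mkString(", "))
        }
      }
    }
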
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 62b521ad45..a35081f7b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+case class SparkListenerTaskGettingResult(
+ task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@@ -54,7 +57,13 @@ trait SparkListener {
/**
* Called when a task starts
*/
- def onTaskStart(taskEnd: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * Called when a task begins remotely fetching its result (will not be called for tasks that do
+ * not need to fetch the result remotely).
+ */
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
@@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
- val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(stage.stageInfo.taskInfos.flatMap{
+ Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 4d3e4a17ba..d5824e7954 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..7cb3fe46e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
+ val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
-
- /** When first task was submitted to scheduler. */
- var submissionTime: Option[Long] = None
- var completionTime: Option[Long] = None
-
private var nextAttemptId = 0
def isAvailable: Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index b6f11969e5..93599dfdc8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,16 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
-case class StageInfo(
- val stage: Stage,
+class StageInfo(
+ stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
- override def toString = stage.rdd.toString
+ val stageId = stage.id
+ /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+ val rddName = stage.rdd.name
+ val name = stage.name
+ val numPartitions = stage.numPartitions
+ val numTasks = stage.numTasks
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 598d91752a..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,25 +17,74 @@
package org.apache.spark.scheduler
-import org.apache.spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream
+
import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
+
/**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Tasks in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output into multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the partition in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
- def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+
+ final def run(attemptId: Long): T = {
+ context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ if (_killed) {
+ kill()
+ }
+ runTask(context)
+ }
+
+ def runTask(context: TaskContext): T
+
def preferredLocations: Seq[TaskLocation] = Nil
- var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+ // Map output tracker epoch. Will be set by TaskScheduler.
+ var epoch: Long = -1
var metrics: Option[TaskMetrics] = None
+ // Task context, to be initialized in run().
+ @transient protected var context: TaskContext = _
+
+ // A flag to indicate whether the task is killed. This is used in case context is not yet
+ // initialized when kill() is invoked.
+ @volatile @transient private var _killed = false
+
+ /**
+ * Whether the task has been killed.
+ */
+ def killed: Boolean = _killed
+
+ /**
+ * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+ * code and user code to properly handle the flag. This function should be idempotent so it can
+ * be called multiple times.
+ */
+ def kill() {
+ _killed = true
+ if (context != null) {
+ context.interrupted = true
+ }
+ }
}
/**
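
To make the new Task contract concrete, a hedged sketch of a custom subclass (not part of this patch): it would have to live in the org.apache.spark.scheduler package because Task is private[spark], and it assumes the interrupted flag set by kill() above is readable on TaskContext.

    import org.apache.spark.TaskContext

    // Illustrative only: sums a sequence in chunks and polls the kill flag between chunks.
    private[spark] class SumTask(stageId: Int, partitionId: Int, data: Seq[Int])
      extends Task[Long](stageId, partitionId) {

      override def runTask(context: TaskContext): Long = {
        var sum = 0L
        for (chunk <- data.grouped(1000)) {
          if (context.interrupted) {
            throw new InterruptedException("task " + stageId + ":" + partitionId + " was killed")
          }
          sum += chunk.map(_.toLong).sum
        }
        sum
      }
    }
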
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 7c2a422aff..4bae26f3a6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
+ /**
+ * The time when the task started remotely getting the result. Will not be set if the
+ * task result was sent immediately when the task finished (as opposed to sending an
+ * IndirectTaskResult and later fetching the result from the block manager).
+ */
+ var gettingResultTime: Long = 0
+
+ /**
+ * The time when the task has completed successfully (including the time to remotely fetch
+ * results, if necessary).
+ */
var finishTime: Long = 0
+
var failed = false
+ def markGettingResult(time: Long = System.currentTimeMillis) {
+ gettingResultTime = time
+ }
+
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
+ def gettingResult: Boolean = gettingResultTime != 0
+
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
+ else if (gettingResult)
+ "GET RESULT"
else if (failed)
"FAILED"
else if (successful)
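
A quick sketch of how the new timestamps drive the status string. TaskInfo is scheduler-internal, so treat this mainly as pseudocode for the field semantics; the constructor arguments before host and taskLocality (taskId, index, launchTime, executorId) are assumptions about the full signature.

    import org.apache.spark.scheduler.{TaskInfo, TaskLocality}

    val info = new TaskInfo(0L, 0, System.currentTimeMillis, "executor-1", "localhost",
      TaskLocality.NODE_LOCAL)
    println(info.status)        // RUNNING while finishTime is unset
    info.markGettingResult()    // recorded when an IndirectTaskResult is being fetched
    println(info.status)        // GET RESULT
    info.markSuccessful()
    println(info.successful)    // true; finishTime now includes the result-fetch time
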
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index db3954a9d3..7e468d0d67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.{SparkEnv}
import java.nio.ByteBuffer
import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
// Task result. Also contains updates to accumulator variables.
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark]
-case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable
+case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c2a9f03d7..10e0478108 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 * Each TaskScheduler schedules tasks for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {
@@ -45,8 +44,11 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
+ // Cancel a stage.
+ def cancelTasks(stageId: Int)
+
+ // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index 593fa9fb93..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import scala.collection.mutable.Map
-
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
- // A task has started.
- def taskStarted(task: Task[_], taskInfo: TaskInfo)
-
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-
- // A node was added to the cluster.
- def executorGained(execId: String, host: String): Unit
-
- // A node was lost from the cluster.
- def executorLost(execId: String): Unit
-
- // The TaskScheduler wants to abort an entire task set.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156..03bf760837 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
+ def kill() {
+ tasks.foreach(_.kill())
+ }
+
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 1a844b7e7e..85033958ef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler.cluster
-import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
@@ -79,14 +78,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
- // JAR server, if any JARs were added by the user to the SparkContext
- var jarServer: HttpServer = null
-
- // URIs of JARs to pass to executor
- var jarUris: String = ""
-
// Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
@@ -101,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
@@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
- def taskSetFinished(manager: TaskSetManager) {
- this.synchronized {
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
+ def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+ // Check to see if the given task set has been removed. This is possible in the case of
+ // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+ // more than one running task).
+ if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
@@ -281,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -290,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ taskSetManager.handleTaskGettingResult(tid)
+ }
+
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
@@ -334,9 +354,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (backend != null) {
backend.stop()
}
- if (jarServer != null) {
- jarServer.stop()
- }
if (taskResultGetter != null) {
taskResultGetter.stop()
}
@@ -384,9 +401,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- // Call listener.executorLost without holding the lock on this to prevent deadlock
+ // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
@@ -405,7 +422,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def executorGained(execId: String, host: String) {
- listener.executorGained(execId, host)
+ dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
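A condensed standalone sketch of the two cancellation cases described in cancelTasks above; the maps and helper methods here are simplified stand-ins, not the real scheduler state.

import scala.collection.mutable

// Simplified stand-ins for the scheduler bookkeeping used by cancelTasks.
object CancelSketch {
  // taskSetId -> running task ids; taskId -> executor id (assumed values)
  val taskSetTaskIds = mutable.Map("0.0" -> mutable.Set(1L, 2L), "1.0" -> mutable.Set.empty[Long])
  val taskIdToExecutorId = mutable.Map(1L -> "exec-A", 2L -> "exec-B")

  def killTaskOnExecutor(tid: Long, execId: String): Unit =
    println(s"kill task $tid on $execId")

  def abort(taskSetId: String): Unit =
    println(s"abort task set $taskSetId")

  def cancel(taskSetId: String): Unit = {
    val taskIds = taskSetTaskIds(taskSetId)
    // Case 1: tasks already scheduled -> kill each one, then abort the stage.
    // Case 2: nothing scheduled yet -> the loop is empty, just abort the stage.
    taskIds.foreach(tid => killTaskOnExecutor(tid, taskIdToExecutorId(tid)))
    abort(taskSetId)
  }

  def main(args: Array[String]): Unit = {
    cancel("0.0")  // kills tasks 1 and 2, then aborts
    cancel("1.0")  // nothing running, aborts immediately
  }
}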
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 194ab55102..ee47aaffca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,18 +17,16 @@
package org.apache.spark.scheduler.cluster
-import java.nio.ByteBuffer
-import java.util.{Arrays, NoSuchElementException}
+import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import scala.Some
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
@@ -417,11 +415,17 @@ private[spark] class ClusterTaskSetManager(
}
private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
- * Marks the task as successful and notifies the listener that a task has ended.
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
@@ -431,7 +435,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.listener.taskEnded(
+ sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
@@ -447,7 +451,8 @@ private[spark] class ClusterTaskSetManager(
}
/**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
@@ -458,54 +463,57 @@ private[spark] class ClusterTaskSetManager(
val index = info.index
info.markFailed()
if (!successful(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
- _ match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
}
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ recentExceptions(key) = (0, now)
+ (true, 0)
}
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
- case TaskResultLost =>
- logInfo("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ case TaskResultLost =>
+ logWarning("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
- case _ => {}
- }
+ case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
@@ -532,7 +540,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
@@ -605,7 +613,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
@@ -630,11 +638,11 @@ private[spark] class ClusterTaskSetManager(
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation) {
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
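To make the speculation change above concrete, a small standalone calculation of the threshold over the durations of successful tasks only; the 0.75 quantile, 1.5 multiplier, and durations are assumed example values.

import java.util.Arrays
import scala.math.{max, min}

// Standalone sketch of the speculation threshold computation shown above.
object SpeculationSketch {
  def main(args: Array[String]): Unit = {
    val numTasks = 8
    val successfulDurations = Array(100L, 120L, 130L, 150L, 900L, 110L)  // ms, assumed
    val tasksSuccessful = successfulDurations.length

    val minFinishedForSpeculation = (0.75 * numTasks).floor.toInt  // 6
    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
      val durations = successfulDurations.clone()
      Arrays.sort(durations)
      // Median is indexed by the number of *successful* tasks, not numTasks,
      // so a partially finished stage cannot index past the sorted array.
      val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1))
      val threshold = max(1.5 * medianDuration, 100)
      println(s"median=$medianDuration ms, speculation threshold=$threshold ms")
    }
  }
}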
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index c0b836bf1a..53316dae2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -24,26 +24,28 @@ import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{Utils, SerializableBuffer}
-private[spark] sealed trait StandaloneClusterMessage extends Serializable
+private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
-private[spark] object StandaloneClusterMessages {
+private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
- case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+ case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage
+
+ case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
- extends StandaloneClusterMessage
+ extends CoarseGrainedClusterMessage
- case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
+ case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
// Executors to driver
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
- extends StandaloneClusterMessage {
+ extends CoarseGrainedClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
- data: SerializableBuffer) extends StandaloneClusterMessage
+ data: SerializableBuffer) extends CoarseGrainedClusterMessage
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
@@ -54,10 +56,14 @@ private[spark] object StandaloneClusterMessages {
}
// Internal messages in driver
- case object ReviveOffers extends StandaloneClusterMessage
+ case object ReviveOffers extends CoarseGrainedClusterMessage
+
+ case object StopDriver extends CoarseGrainedClusterMessage
+
+ case object StopExecutor extends CoarseGrainedClusterMessage
- case object StopDriver extends StandaloneClusterMessage
+ case object StopExecutors extends CoarseGrainedClusterMessage
- case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage
+ case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index b6f0ec961a..3ccc38d72b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -29,16 +29,19 @@ import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycle
import org.apache.spark.{SparkException, Logging, TaskState}
import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.Utils
/**
- * A standalone scheduler backend, which waits for standalone executors to connect to it through
- * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
- * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
+ * A scheduler backend that waits for coarse-grained executors to connect to it through Akka.
+ * This backend holds onto each executor for the duration of the Spark job rather than relinquishing
+ * executors whenever a task is done and asking the scheduler to launch a new executor for
+ * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
+ * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode
+ * (spark.deploy.*).
*/
private[spark]
-class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -84,17 +87,33 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
makeOffers()
+ case KillTask(taskId, executorId) =>
+ executorActor(executorId) ! KillTask(taskId, executorId)
+
case StopDriver =>
sender ! true
context.stop(self)
+ case StopExecutors =>
+ logInfo("Asking each executor to shut down")
+ for (executor <- executorActor.values) {
+ executor ! StopExecutor
+ }
+ sender ! true
+
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
sender ! true
@@ -159,16 +178,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val timeout = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ }
+
+ def stopExecutors() {
+ try {
+ if (driverActor != null) {
+ logInfo("Shutting down all executors")
+ val future = driverActor.ask(StopExecutors)(timeout)
+ Await.ready(future, timeout)
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error asking standalone scheduler to shut down executors", e)
+ }
+ }
override def stop() {
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
}
} catch {
case e: Exception =>
@@ -180,6 +214,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
driverActor ! ReviveOffers
}
+ override def killTask(taskId: Long, executorId: String) {
+ driverActor ! KillTask(taskId, executorId)
+ }
+
override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
.map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
@@ -187,7 +225,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
def removeExecutor(executorId: String, reason: String) {
try {
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error notifying standalone scheduler's driver actor", e)
@@ -195,6 +233,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
-private[spark] object StandaloneSchedulerBackend {
- val ACTOR_NAME = "StandaloneScheduler"
+private[spark] object CoarseGrainedSchedulerBackend {
+ val ACTOR_NAME = "CoarseGrainedScheduler"
}
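For reference, a small Akka-free sketch of the difference between Await.ready (used above) and Await.result: ready only waits for the future to complete and returns it, while result would rethrow a failure.

import scala.concurrent.{Await, Future, Promise}
import scala.concurrent.duration._

// Sketch of Await.ready vs Await.result using a plain Promise instead of an actor reply.
object AwaitSketch {
  def main(args: Array[String]): Unit = {
    val p = Promise[Boolean]()
    p.failure(new RuntimeException("executor already gone"))

    // Await.ready only waits for completion; a failed future does not throw here.
    val f: Future[Boolean] = Await.ready(p.future, 10.seconds)
    println(f.value)  // Some(Failure(java.lang.RuntimeException: executor already gone))

    // Await.result would rethrow the failure instead:
    // Await.result(p.future, 10.seconds)  // throws RuntimeException
  }
}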
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
index d57eb3276f..5367218faa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+
// Memory used by each executor (in megabytes)
protected val executorMemory: Int = SparkContext.executorMemoryRequested
-
- // TODO: Probably want to add a killTask too
}
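A minimal sketch of the new killTask hook: backends that support killing override it, while others inherit the throwing default. The stub trait and recording backend below are illustrative stand-ins, not Spark classes.

// Sketch only: a stub backend that records kill requests instead of sending them anywhere.
trait BackendSketch {
  def killTask(taskId: Long, executorId: String): Unit =
    throw new UnsupportedOperationException
}

class RecordingBackend extends BackendSketch {
  val killed = scala.collection.mutable.ArrayBuffer.empty[(Long, String)]
  override def killTask(taskId: Long, executorId: String): Unit =
    killed += ((taskId, executorId))
}

object BackendSketch {
  def main(args: Array[String]): Unit = {
    val b = new RecordingBackend
    b.killTask(42L, "exec-1")
    println(b.killed)  // ArrayBuffer((42,exec-1))
  }
}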
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
new file mode 100644
index 0000000000..d78bdbaa7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.{Logging, SparkContext}
+
+private[spark] class SimrSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ driverFilePath: String)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ val tmpPath = new Path(driverFilePath + "_tmp")
+ val filePath = new Path(driverFilePath)
+
+ val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
+
+ override def start() {
+ super.start()
+
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+
+ logInfo("Writing to HDFS file: " + driverFilePath)
+ logInfo("Writing Akka address: " + driverUrl)
+
+ // Create a temporary file to prevent a race condition where executors see an empty driverUrl file
+ val temp = fs.create(tmpPath, true)
+ temp.writeUTF(driverUrl)
+ temp.writeInt(maxCores)
+ temp.close()
+
+ // "Atomic" rename
+ fs.rename(tmpPath, filePath)
+ }
+
+ override def stop() {
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+ fs.delete(new Path(driverFilePath), false)
+ super.stopExecutors()
+ super.stop()
+ }
+}
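A local-filesystem analogue (illustrative only, using java.nio rather than the Hadoop FileSystem API shown above) of the write-to-temp-then-rename pattern, so readers never observe a partially written driver URL file; the URL value is assumed.

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, StandardCopyOption}

// Readers either see no file or a fully written one, never a partial write.
object AtomicPublishSketch {
  def main(args: Array[String]): Unit = {
    val dir = Files.createTempDirectory("simr-sketch")
    val tmpPath = dir.resolve("driverurl_tmp")
    val finalPath = dir.resolve("driverurl")

    val driverUrl = "akka://spark@driver-host:7077/user/CoarseGrainedScheduler" // assumed value
    Files.write(tmpPath, driverUrl.getBytes(StandardCharsets.UTF_8))

    // Publish in one step; ATOMIC_MOVE fails fast if the filesystem cannot guarantee it.
    Files.move(tmpPath, finalPath, StandardCopyOption.ATOMIC_MOVE)

    println(new String(Files.readAllBytes(finalPath), StandardCharsets.UTF_8))
  }
}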
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index fa83ae19d6..7127a72d6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -26,9 +26,9 @@ import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
- master: String,
+ masters: Array[String],
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -44,15 +44,15 @@ private[spark] class SparkDeploySchedulerBackend(
// The endpoint for executors to talk to us
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command(
- "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
+ "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
- client = new Client(sc.env.actorSystem, master, appDesc, this)
+ client = new Client(sc.env.actorSystem, masters, appDesc, this)
client.start()
}
@@ -71,8 +71,14 @@ private[spark] class SparkDeploySchedulerBackend(
override def disconnected() {
if (!stopping) {
- logError("Disconnected from Spark cluster!")
- scheduler.error("Disconnected from Spark cluster")
+ logWarning("Disconnected from Spark cluster! Waiting for reconnection...")
+ }
+ }
+
+ override def dead() {
+ if (!stopping) {
+ logError("Spark cluster looks dead, giving up.")
+ scheduler.error("Spark cluster looks down")
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index b2a8f06472..e68c527713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -24,33 +24,16 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
- private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
- private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
- private val getTaskResultExecutor = new ThreadPoolExecutor(
- MIN_THREADS,
- MAX_THREADS,
- 0L,
- TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable],
- new ResultResolverThreadFactory)
-
- class ResultResolverThreadFactory extends ThreadFactory {
- private var counter = 0
- private var PREFIX = "Result resolver thread"
-
- override def newThread(r: Runnable): Thread = {
- val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
- counter += 1
- thread.setDaemon(true)
- return thread
- }
- }
+ private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+ private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+ THREADS, "Result resolver thread")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
@@ -67,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
+ scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index bf4040fafc..8de9b72b2f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,13 +30,14 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
* onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
* a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
- * StandaloneBackend mechanism. This class is useful for lower and more predictable latency.
+ * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable
+ * latency.
*
* Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
* remove this.
@@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
sc: SparkContext,
master: String,
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend(
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val uri = System.getProperty("spark.executor.uri")
if (uri == null) {
val runScript = new File(sparkHome, "spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
+ .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
return command.build()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 4d1bb1c639..2699f0b33e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
package org.apache.spark.scheduler.local
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
/**
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -41,52 +37,57 @@ import org.apache.spark.util.Utils
* testing fault recovery.
*/
-private[spark]
+private[local]
case class LocalReviveOffers()
-private[spark]
+private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+private[local]
+case class KillTask(taskId: Long)
+
private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+ extends Actor with Logging {
+
+ val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
+
case LocalStatusUpdate(taskId, state, serializeData) =>
- freeCores += 1
- localScheduler.statusUpdate(taskId, state, serializeData)
- launchTask(localScheduler.resourceOffer(freeCores))
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
}
- def launchTask(tasks : Seq[TaskDescription]) {
+ private def launchTask(tasks: Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
- localScheduler.threadPool.submit(new Runnable {
- def run() {
- localScheduler.runTask(task.taskId, task.serializedTask)
- }
- })
+ executor.launchTask(localScheduler, task.taskId, task.serializedTask)
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
+ with ExecutorBackend
with Logging {
- var attemptId = new AtomicInteger(0)
- var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
- var listener: TaskSchedulerListener = null
+ val attemptId = new AtomicInteger
+ var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
- val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -113,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
@@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ localActor ! KillTask(tid)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
@@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- def runTask(taskId: Long, bytes: ByteBuffer) {
- logInfo("Running " + taskId)
- val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objectSer = SparkEnv.get.serializer.newInstance()
- var attemptedTask: Option[Task[_]] = None
- val start = System.currentTimeMillis()
- var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
-
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- attemptedTask = Some(deserializedTask)
- val deserTime = System.currentTimeMillis() - start
- taskStart = System.currentTimeMillis()
-
- // Run it
- val result: Any = deserializedTask.run(taskId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = objectSer.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = objectSer.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- val serviceTime = System.currentTimeMillis() - taskStart
- logInfo("Finished " + taskId)
- deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
- deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new DirectTaskResult(
- result, accumUpdates, deserializedTask.metrics.getOrElse(null))
- val serializedResult = ser.serialize(taskResult)
- localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
- } catch {
- case t: Throwable => {
- val serviceTime = System.currentTimeMillis() - taskStart
- val metrics = attemptedTask.flatMap(t => t.metrics)
- for (m <- metrics) {
- m.executorRunTime = serviceTime.toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
- }
- val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
- }
- }
- }
-
- /**
- * Download any missing dependencies if we receive a new set of files and JARs from the
- * SparkContext. Also adds any new JARs we fetched to the class loader.
- */
- private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- synchronized {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentFiles(name) = timestamp
- }
-
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ if (TaskState.isFinished(state)) {
+ synchronized {
+ taskIdToTaskSetId.get(taskId) match {
+ case Some(taskSetId) =>
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ taskSetManager.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ taskSetManager.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ taskSetManager.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
}
}
+ localActor ! LocalStatusUpdate(taskId, state, serializedData)
}
}
- def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
- synchronized {
- val taskSetId = taskIdToTaskSetId(taskId)
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
- taskSetManager.statusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- threadPool.shutdownNow()
+ override def stop() {
}
override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..53bf78267e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,19 +132,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskEnded(tid, state, serializedData)
- case TaskState.FAILED =>
- taskFailed(tid, state, serializedData)
- case _ => {}
- }
- }
-
def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -159,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+ result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
@@ -176,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
@@ -185,9 +175,9 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, 4, reason.description)
+ taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
- sched.listener.taskSetFailed(taskSet, errorMessage)
+ sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
@@ -195,5 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.taskSetFinished(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index e936b1cfed..55b25f145a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel}
-
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
+ val blockId = TestBlockId("1")
// Register some commonly used classes
val toRegister: Seq[AnyRef] = Seq(
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock("1", ByteBuffer.allocate(1)),
- GetBlock("1"),
+ PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+ GotBlock(blockId, ByteBuffer.allocate(1)),
+ GetBlock(blockId),
1 to 10,
1 until 10,
1L to 10L,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
index 290dbce4f5..0d0a2dadc7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
@@ -18,5 +18,5 @@
package org.apache.spark.storage
private[spark]
-case class BlockException(blockId: String, message: String) extends Exception(message)
+case class BlockException(blockId: BlockId, message: String) extends Exception(message)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 3aeda3879d..e51c5b30a3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
with Logging with BlockFetchTracker {
def initialize()
}
@@ -57,20 +57,20 @@ private[storage]
object BlockFetcherIterator {
// A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserialization to happen in the calling thread); can also
// represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BlockFetcherIterator {
@@ -92,12 +92,12 @@ object BlockFetcherIterator {
// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[String]()
+ protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[String]()
+ protected val remoteBlocksToFetch = new HashSet[BlockId]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -167,7 +167,7 @@ object BlockFetcherIterator {
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
@@ -183,7 +183,7 @@ object BlockFetcherIterator {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
}
}
// Add in the final request
@@ -241,7 +241,7 @@ object BlockFetcherIterator {
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val startFetchWait = System.currentTimeMillis()
val result = results.take()
@@ -267,7 +267,7 @@ object BlockFetcherIterator {
class NettyBlockFetcherIterator(
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
@@ -303,7 +303,7 @@ object BlockFetcherIterator {
override protected def sendRequest(req: FetchRequest) {
- def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
val fetchResult = new FetchResult(blockId, blockSize,
() => dataDeserialize(blockId, blockData.nioBuffer, serializer))
results.put(fetchResult)
@@ -337,7 +337,7 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// If all the results have been retrieved, copiers will exit automatically
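A simplified standalone sketch of the request-batching idea used when splitting blocks into FetchRequests: remote blocks are grouped until the running size reaches a target, so no single request grows too large. The target size and block list are assumed example values.

import scala.collection.mutable.ArrayBuffer

// Batches (blockId, size) pairs into fetch requests that each stay near a target size.
object FetchBatchSketch {
  def main(args: Array[String]): Unit = {
    val targetRequestSize = 100L
    val blocks = Seq(("shuffle_0_0_1", 60L), ("shuffle_0_1_1", 50L),
                     ("shuffle_0_2_1", 10L), ("shuffle_0_3_1", 120L))

    val requests = ArrayBuffer[Seq[(String, Long)]]()
    var cur = ArrayBuffer[(String, Long)]()
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) {   // zero-sized blocks are skipped
      cur += ((id, size))
      curSize += size
      if (curSize >= targetRequestSize) {      // flush the current batch
        requests += cur.toSeq
        cur = ArrayBuffer[(String, Long)]()
        curSize = 0L
      }
    }
    if (cur.nonEmpty) requests += cur.toSeq    // final partial batch
    requests.foreach(r => println(r.map(_._1).mkString("request: ", ", ", "")))
  }
}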
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
new file mode 100644
index 0000000000..7156d855d8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+/**
+ * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different
+ * set of keys which produce its unique name.
+ *
+ * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method below.
+ */
+private[spark] sealed abstract class BlockId {
+ /** A globally unique identifier for this Block. Can be used for ser/de. */
+ def name: String
+
+ // convenience methods
+ def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
+ def isRDD = isInstanceOf[RDDBlockId]
+ def isShuffle = isInstanceOf[ShuffleBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+
+ override def toString = name
+ override def hashCode = name.hashCode
+ override def equals(other: Any): Boolean = other match {
+ case o: BlockId => getClass == o.getClass && name.equals(o.name)
+ case _ => false
+ }
+}
+
+private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
+ def name = "rdd_" + rddId + "_" + splitIndex
+}
+
+private[spark]
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
+}
+
+private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
+ def name = "broadcast_" + broadcastId
+}
+
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
+private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
+ def name = "taskresult_" + taskId
+}
+
+private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+ def name = "input-" + streamId + "-" + uniqueId
+}
+
+// Intended only for testing purposes
+private[spark] case class TestBlockId(id: String) extends BlockId {
+ def name = "test_" + id
+}
+
+private[spark] object BlockId {
+ val RDD = "rdd_([0-9]+)_([0-9]+)".r
+ val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+ val TASKRESULT = "taskresult_([0-9]+)".r
+ val STREAM = "input-([0-9]+)-([0-9]+)".r
+ val TEST = "test_(.*)".r
+
+ /** Converts a BlockId "name" String back into a BlockId. */
+ def apply(id: String) = id match {
+ case RDD(rddId, splitIndex) =>
+ RDDBlockId(rddId.toInt, splitIndex.toInt)
+ case SHUFFLE(shuffleId, mapId, reduceId) =>
+ ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case BROADCAST(broadcastId) =>
+ BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+ case TASKRESULT(taskId) =>
+ TaskResultBlockId(taskId.toLong)
+ case STREAM(streamId, uniqueId) =>
+ StreamBlockId(streamId.toInt, uniqueId.toLong)
+ case TEST(value) =>
+ TestBlockId(value)
+ case _ =>
+ throw new IllegalStateException("Unrecognized BlockId: " + id)
+ }
+}
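
For reference, a minimal sketch of how the new typed ids round-trip through their string names (illustrative values only; the classes are private[spark], so this assumes code inside an org.apache.spark package):

    // A typed id renders to the same string the old code passed around ...
    val shuffle = ShuffleBlockId(shuffleId = 1, mapId = 2, reduceId = 3)
    assert(shuffle.name == "shuffle_1_2_3")

    // ... and BlockId.apply parses that string back into the same typed id.
    assert(BlockId("shuffle_1_2_3") == shuffle)
    val rdd = BlockId("rdd_5_0").asRDDId        // Some(RDDBlockId(5, 0))

    // Unrecognized names fail fast instead of propagating as raw strings:
    // BlockId("bogus_0") throws IllegalStateException.
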
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
new file mode 100644
index 0000000000..c8f397609a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.ConcurrentHashMap
+
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ // To save space, 'pending' and 'failed' are encoded as special sizes:
+ @volatile var size: Long = BlockInfo.BLOCK_PENDING
+ private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
+ private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
+ private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non-trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
+ }
+
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
+ if (pending && initThread != Thread.currentThread()) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ !failed
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
+ assert (pending)
+ size = sizeInBytes
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ assert (pending)
+ size = BlockInfo.BLOCK_FAILED
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+}
+
+private object BlockInfo {
+ // initThread is logically a BlockInfo field, but we store it here because
+ // it's only needed while this block is in the 'pending' state and we want
+ // to minimize BlockInfo's memory footprint.
+ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
+
+ private val BLOCK_PENDING: Long = -1L
+ private val BLOCK_FAILED: Long = -2L
+}
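
A rough sketch of the lifecycle the extracted class expects: the writing thread publishes with markReady or markFailure, and other threads park in waitForReady. It assumes storage-package code, since BlockInfo is private[storage].

    // Writer thread: register the info, write the block, then publish the outcome.
    val info = new BlockInfo(StorageLevel.MEMORY_ONLY, tellMaster = true)
    try {
      // ... actually write the block to the memory or disk store ...
      info.markReady(sizeInBytes = 1024L)        // wakes any readers parked in waitForReady()
    } catch {
      case _: Exception => info.markFailure()    // readers then see waitForReady() == false
    }

    // Reader thread: wait for the writer, then check whether the write succeeded.
    if (info.waitForReady()) {
      // block is available; info.size now holds the real size
    } else {
      // the write failed; recompute or fetch the block from elsewhere
    }
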
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 7852849ce5..252329c4e1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,17 +17,18 @@
package org.apache.spark.storage
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.collection.mutable.{HashMap, ArrayBuffer}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration
import scala.concurrent.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -37,7 +38,6 @@ import org.apache.spark.util._
import sun.nio.ch.DirectBuffer
-
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -46,74 +46,20 @@ private[spark] class BlockManager(
maxMemory: Long)
extends Logging {
- private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- @volatile var pending: Boolean = true
- @volatile var size: Long = -1L
- @volatile var initThread: Thread = null
- @volatile var failed = false
-
- setInitThread()
-
- private def setInitThread() {
- // Set current thread as init thread - waitForReady will not block this thread
- // (in case there is non trivial initialization which ends up calling waitForReady as part of
- // initialization itself)
- this.initThread = Thread.currentThread()
- }
-
- /**
- * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
- * Return true if the block is available, false otherwise.
- */
- def waitForReady(): Boolean = {
- if (initThread != Thread.currentThread() && pending) {
- synchronized {
- while (pending) this.wait()
- }
- }
- !failed
- }
-
- /** Mark this BlockInfo as ready (i.e. block is finished writing) */
- def markReady(sizeInBytes: Long) {
- assert (pending)
- size = sizeInBytes
- initThread = null
- failed = false
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
-
- /** Mark this BlockInfo as ready but failed */
- def markFailure() {
- assert (pending)
- size = 0
- initThread = null
- failed = true
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
- }
-
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
- private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+ private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -154,7 +100,8 @@ private[spark] class BlockManager(
var heartBeatTask: Cancellable = null
- val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
+ private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks)
+ private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks)
initialize()
// The compression codec to use. Note that the "lazy" val is necessary because we want to delay
@@ -248,7 +195,7 @@ private[spark] class BlockManager(
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -258,7 +205,7 @@ private[spark] class BlockManager(
* droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
* This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
@@ -269,11 +216,11 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+ private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -298,7 +245,7 @@ private[spark] class BlockManager(
/**
* Get locations of an array of blocks.
*/
- def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+ def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
@@ -310,7 +257,7 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
diskStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}
@@ -318,94 +265,19 @@ private[spark] class BlockManager(
/**
* Get block from local block manager.
*/
- def getLocal(blockId: String): Option[Iterator[Any]] = {
+ def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
-
- // In the another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug("Level for block " + blockId + " is " + level)
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- logDebug("Block " + blockId + " not found in memory")
- }
- }
-
- // Look for block on disk, potentially loading it back into memory if required
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- if (level.useMemory && level.deserialized) {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- // Put the block back in memory before returning it
- // TODO: Consider creating a putValues that also takes in a iterator ?
- val elements = new ArrayBuffer[Any]
- elements ++= iterator
- memoryStore.putValues(blockId, elements, level, true).data match {
- case Left(iterator2) =>
- return Some(iterator2)
- case _ =>
- throw new Exception("Memory store did not return back an iterator")
- }
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else if (level.useMemory && !level.deserialized) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- // Put a copy of the block back in memory before returning it. Note that we can't
- // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
- // The use of rewind assumes this.
- assert (0 == bytes.position())
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- bytes.rewind()
- return Some(dataDeserialize(blockId, bytes))
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
- }
- }
- } else {
- logDebug("Block " + blockId + " not registered locally")
- }
- return None
+ doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
+ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug("Getting local block " + blockId + " as bytes")
-
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (ShuffleBlockManager.isShuffle(blockId)) {
+ if (blockId.isShuffle) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -413,12 +285,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
+ doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+ private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- // In the another thread is writing the block, wait for it to become ready.
+ // If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@@ -431,62 +306,104 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
- memoryStore.getBytes(blockId) match {
- case Some(bytes) =>
- return Some(bytes)
+ val result = if (asValues) {
+ memoryStore.getValues(blockId)
+ } else {
+ memoryStore.getBytes(blockId)
+ }
+ result match {
+ case Some(values) =>
+ return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
- // Look for block on disk
+ // Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- assert (0 == bytes.position())
- if (level.useMemory) {
- if (level.deserialized) {
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- // The memory store will hang onto the ByteBuffer, so give it a copy instead of
- // the memory-mapped file buffer we got from the disk store
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- }
- }
- bytes.rewind()
- return Some(bytes)
+ logDebug("Getting block " + blockId + " from disk")
+ val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
+ assert (0 == bytes.position())
+
+ if (!level.useMemory) {
+ // If the block shouldn't be stored in memory, we can just return it:
+ if (asValues) {
+ return Some(dataDeserialize(blockId, bytes))
+ } else {
+ return Some(bytes)
+ }
+ } else {
+ // Otherwise, we also have to store something in the memory store:
+ if (!level.deserialized || !asValues) {
+ // We'll store the bytes in memory if the block's storage level includes
+ // "memory serialized", or if it should be cached as objects in memory
+ // but we only requested its serialized bytes:
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ }
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ val values = dataDeserialize(blockId, bytes)
+ if (level.deserialized) {
+ // Cache the values before returning them:
+ // TODO: Consider creating a putValues that also takes in an iterator?
+ val valuesBuffer = new ArrayBuffer[Any]
+ valuesBuffer ++= values
+ memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+ case Left(values2) =>
+ return Some(values2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ } else {
+ return Some(values)
+ }
+ }
+ }
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
- return None
+ None
}
/**
* Get block from remote block managers.
*/
- def getRemote(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
+ def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = master.getLocations(blockId)
+ doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
+ }
- // Get block from remote locations
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ logDebug("Getting remote block " + blockId + " as bytes")
+ doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+
+ private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
+ require(blockId != null, "BlockId is null")
+ val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
- return Some(dataDeserialize(blockId, data))
+ if (asValues) {
+ return Some(dataDeserialize(blockId, data))
+ } else {
+ return Some(data)
+ }
}
logDebug("The value of block " + blockId + " is null")
}
@@ -495,34 +412,9 @@ private[spark] class BlockManager(
}
/**
- * Get block from remote block managers as serialized bytes.
- */
- def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
- // refactored.
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId + " as bytes")
-
- val locations = master.getLocations(blockId)
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
- if (data != null) {
- return Some(data)
- }
- logDebug("The value of block " + blockId + " is null")
- }
- logDebug("Block " + blockId + " not found")
- return None
- }
-
- /**
* Get a block from the block manager (either local or remote).
*/
- def get(blockId: String): Option[Iterator[Any]] = {
+ def get(blockId: BlockId): Option[Iterator[Any]] = {
val local = getLocal(blockId)
if (local.isDefined) {
logInfo("Found block %s locally".format(blockId))
@@ -543,7 +435,7 @@ private[spark] class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
val iter =
@@ -557,7 +449,7 @@ private[spark] class BlockManager(
iter
}
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
val elements = new ArrayBuffer[Any]
elements ++= values
@@ -566,35 +458,38 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The block will be appended to the given file.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
- writer.registerCloseEventHandler(() => {
- val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
- })
- writer
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
+ tellMaster: Boolean = true) : Long = {
+ require(values != null, "Values is null")
+ doPut(blockId, Left(values), level, tellMaster)
+ }
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
+ tellMaster: Boolean = true) {
+ require(bytes != null, "Bytes is null")
+ doPut(blockId, Right(bytes), level, tellMaster)
+ }
+
+ private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ level: StorageLevel, tellMaster: Boolean = true): Long = {
+ require(blockId != null, "BlockId is null")
+ require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
@@ -610,7 +505,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ // TODO: So the block info exists - but the previous attempt to load it (?) failed.
+ // What do we do now? Retry on it?
oldBlockOpt.get
} else {
tinfo
@@ -619,10 +515,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
- // If we need to replicate the data, we'll want access to the values, but because our
- // put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where it
- // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ // If we're storing values and we need to replicate the data, we'll want access to the values,
+ // but because our put will read the whole iterator, there will be no values left. For the
+ // case where the put serializes data, we'll remember the bytes, above; but for the case where
+ // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@@ -631,30 +527,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (data.isRight && level.replication > 1) {
+ val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
+ data match {
+ case Left(values) => {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
}
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ case Right(bytes) => {
+ bytes.rewind()
+ // Store it only in memory at first, even if useDisk is also set to true
+ (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
+ size = bytes.limit
}
}
@@ -679,132 +596,46 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
if (level.replication > 1) {
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, level)
- logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
- }
- BlockManager.dispose(bytesAfterPut)
-
- return size
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(
- blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- // Remember the block's storage level so that we can correctly drop it to disk if it needs
- // to be dropped right after it got put into memory. Note, however, that other threads will
- // not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
-
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
- oldBlockOpt.get
- } else {
- tinfo
- }
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
- Future {
- replicate(blockId, bufferView, level)
- }
- } else {
- null
- }
-
- myInfo.synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- var marked = false
- try {
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // assert (0 == bytes.position(), "" + bytes)
-
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
- }
- } finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (! marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed")
+ data match {
+ case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case Left(values) => {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " +
+ Utils.getUsedTimeMs(remoteStartTime))
}
}
}
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- Await.ready(replicationFuture, Duration.Inf)
- }
+ BlockManager.dispose(bytesAfterPut)
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
+
+ size
}
/**
* Replicate block to another node.
*/
var cachedPeers: Seq[BlockManagerId] = null
- private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
@@ -827,14 +658,14 @@ private[spark] class BlockManager(
/**
* Read a block consisting of a single object.
*/
- def getSingle(blockId: String): Option[Any] = {
+ def getSingle(blockId: BlockId): Option[Any] = {
get(blockId).map(_.next())
}
/**
* Write a block consisting of a single object.
*/
- def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
put(blockId, Iterator(value), level, tellMaster)
}
@@ -842,7 +673,7 @@ private[spark] class BlockManager(
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*/
- def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
val info = blockInfo.get(blockId).orNull
if (info != null) {
@@ -891,16 +722,15 @@ private[spark] class BlockManager(
// TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
// from RDD.id to blocks.
logInfo("Removing RDD " + rddId)
- val rddPrefix = "rdd_" + rddId + "_"
- val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
- blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
blocksToRemove.size
}
/**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String, tellMaster: Boolean = true) {
+ def removeBlock(blockId: BlockId, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -921,13 +751,22 @@ private[spark] class BlockManager(
}
}
- def dropOldBlocks(cleanupTime: Long) {
- logInfo("Dropping blocks older than " + cleanupTime)
+ private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping non broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, !_.isBroadcast)
+ }
+
+ private def dropOldBroadcastBlocks(cleanupTime: Long) {
+ logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, _.isBroadcast)
+ }
+
+ private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime) {
+ if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -944,39 +783,45 @@ private[spark] class BlockManager(
}
}
- def shouldCompress(blockId: String): Boolean = {
- if (ShuffleBlockManager.isShuffle(blockId)) {
- compressShuffle
- } else if (blockId.startsWith("broadcast_")) {
- compressBroadcast
- } else if (blockId.startsWith("rdd_")) {
- compressRdds
- } else {
- false // Won't happen in a real cluster, but it can in tests
- }
+ def shouldCompress(blockId: BlockId): Boolean = blockId match {
+ case ShuffleBlockId(_, _, _) => compressShuffle
+ case BroadcastBlockId(_) => compressBroadcast
+ case RDDBlockId(_, _) => compressRdds
+ case _ => false
}
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
- blockId: String,
+ blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
@@ -986,7 +831,7 @@ private[spark] class BlockManager(
* the iterator is reached.
*/
def dataDeserialize(
- blockId: String,
+ blockId: BlockId,
bytes: ByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
@@ -1004,6 +849,7 @@ private[spark] class BlockManager(
memoryStore.clear()
diskStore.clear()
metadataCleaner.cancel()
+ broadcastCleaner.cancel()
logInfo("BlockManager stopped")
}
}
@@ -1041,10 +887,10 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToBlockManagers(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[BlockManagerId]] =
+ : Map[BlockId, Seq[BlockManagerId]] =
{
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
@@ -1054,7 +900,7 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster.getLocations(blockIds)
}
- val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+ val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
for (i <- 0 until blockIds.length) {
blockManagers(blockIds(i)) = blockLocations(i)
}
@@ -1062,19 +908,19 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToExecutorIds(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
def blockIdsToHosts(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 0c977f05d1..48d7101b0a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -69,7 +69,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
@@ -80,12 +80,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/** Get locations of the blockId from the driver */
- def getLocations(blockId: String): Seq[BlockManagerId] = {
+ def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
- def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
@@ -103,7 +103,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
askDriverWithReply(RemoveBlock(blockId))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 3776951782..154a3980e9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+ private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
val akkaTimeout = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@@ -130,10 +130,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// First remove the metadata for the given RDD, and then asynchronously remove the blocks
// from the slaves.
- val prefix = "rdd_" + rddId + "_"
// Find all blocks for the given RDD, remove the block from both blockLocations and
// the blockManagerInfo that is tracking the blocks.
- val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocks.foreach { blockId =>
val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
@@ -199,7 +198,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: String) {
+ private def removeBlockFromWorkers(blockId: BlockId) {
val locations = blockLocations.get(blockId)
if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId =>
@@ -229,9 +228,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
@@ -248,7 +245,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long) {
@@ -293,11 +290,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map(blockId => getLocations(blockId))
}
@@ -331,7 +328,7 @@ object BlockManagerMasterActor {
private var _remainingMem: Long = maxMem
// Mapping from block id to its status.
- private val _blocks = new JHashMap[String, BlockStatus]
+ private val _blocks = new JHashMap[BlockId, BlockStatus]
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
@@ -340,7 +337,7 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) {
updateLastSeenMs()
@@ -384,7 +381,7 @@ object BlockManagerMasterActor {
}
}
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId)
@@ -395,7 +392,7 @@ object BlockManagerMasterActor {
def lastSeenMs: Long = _lastSeenMs
- def blocks: JHashMap[String, BlockStatus] = _blocks
+ def blocks: JHashMap[BlockId, BlockStatus] = _blocks
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 24333a179c..45f51da288 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+ case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages {
class UpdateBlockInfo(
var blockManagerId: BlockManagerId,
- var blockId: String,
+ var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
var diskSize: Long)
@@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages {
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
+ out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
@@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
- blockId = in.readUTF()
+ blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
@@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages {
object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): UpdateBlockInfo = {
@@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages {
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
}
}
- case class GetLocations(blockId: String) extends ToBlockManagerMaster
+ case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster
- case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+ case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
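
The typed id still crosses the wire as its string name, as the writeExternal/readExternal changes above show; a tiny sketch of the round trip this relies on, using in-memory streams and an illustrative id:

    import java.io._

    val buf = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buf)
    out.writeUTF(RDDBlockId(1, 0).name)   // the id is written as the string "rdd_1_0"
    out.flush()

    val in = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    val restored = BlockId(in.readUTF())  // parsed back into RDDBlockId(1, 0)
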
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 951503019f..3a65e55733 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
+private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 678c38203c..0c66addf9d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
}
- private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level)
@@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
+ " with data size: " + bytes.limit)
}
- private def getBlock(id: String): ByteBuffer = {
+ private def getBlock(id: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + id + " started from " + startTimeMs)
val buffer = blockManager.getLocalBytes(id) match {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index d8fa6a91d1..80dcb5a207 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.network._
-private[spark] case class GetBlock(id: String)
-private[spark] case class GotBlock(id: String, data: ByteBuffer)
-private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
private[spark] class BlockMessage() {
// Un-initialized: typ = 0
@@ -34,7 +34,7 @@ private[spark] class BlockMessage() {
// GotBlock: typ = 2
// PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: String = null
+ private var id: BlockId = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
@@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
- id = idBuilder.toString()
+ id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@@ -109,28 +109,17 @@ private[spark] class BlockMessage() {
set(buffer)
}
- def getType: Int = {
- return typ
- }
-
- def getId: String = {
- return id
- }
-
- def getData: ByteBuffer = {
- return data
- }
-
- def getLevel: StorageLevel = {
- return level
- }
+ def getType: Int = typ
+ def getId: BlockId = id
+ def getData: ByteBuffer = data
+ def getLevel: StorageLevel = level
def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
- buffer.putInt(typ).putInt(id.length())
- id.foreach((x: Char) => buffer.putChar(x))
+ var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+ buffer.putInt(typ).putInt(id.name.length)
+ id.name.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
buffers += buffer
@@ -212,7 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
+ val blockId = TestBlockId("ABC")
+ B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
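
For reference, the header size that toBufferMessage now computes from the typed id's name; this is the same arithmetic as the allocate call above, with illustrative values:

    val id = TestBlockId("ABC")                      // name == "test_ABC"
    val headerSize = 4 + 4 + id.name.length * 2      // type + name length + UTF-16 chars
    // headerSize == 4 + 4 + 8 * 2 == 24 bytes, before any block payload follows
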
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index 0aaf846b5b..6ce9127c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -111,14 +111,15 @@ private[spark] object BlockMessageArray {
}
def main(args: Array[String]) {
- val blockMessages =
+ val blockMessages =
(0 until 10).map { i =>
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
+ BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+ StorageLevel.MEMORY_ONLY_SER))
} else {
- BlockMessage.fromGetBlock(GetBlock(i.toString))
+ BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 39f103297f..469e68fed7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -25,22 +32,14 @@ package org.apache.spark.storage
*
* This interface does not support concurrent writes.
*/
-abstract class BlockObjectWriter(val blockId: String) {
-
- var closeEventHandler: () => Unit = _
+abstract class BlockObjectWriter(val blockId: BlockId) {
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -59,7 +58,126 @@ abstract class BlockObjectWriter(val blockId: String) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
+ */
+ def fileSegment(): FileSegment
+
+ /**
+ * Cumulative time spent performing blocking writes, in ns.
*/
- def size(): Long
+ def timeWriting(): Long
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private val initialPosition = file.length()
+ private var lastValidPosition = initialPosition
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition equals the initial file position if the stream was never initialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
}
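For orientation, a minimal sketch of how a caller might drive the new DiskBlockObjectWriter: commit() flushes and returns the bytes added since the last commit, revertPartialWrites() truncates back to the last committed position, and fileSegment() covers only committed data. The temp file and the identity compressStream are placeholders, and KryoSerializer's no-arg constructor is borrowed from StoragePerfTester further down; this is illustrative, not how shuffle tasks are actually wired up.

package org.apache.spark.storage

import java.io.File
import org.apache.spark.serializer.KryoSerializer

object DiskWriterSketch {
  def main(args: Array[String]) {
    val file = File.createTempFile("writer-sketch", ".dat")   // placeholder location
    val writer = new DiskBlockObjectWriter(
      TestBlockId("demo"), file, new KryoSerializer(), 32 * 1024, out => out)

    writer.open()
    writer.write("first record")
    writer.write("second record")
    val committed = writer.commit()     // bytes written since the last commit

    writer.write("a record we decide to discard")
    writer.revertPartialWrites()        // truncate back to the last committed position

    writer.close()
    val segment = writer.fileSegment()  // (file, offset = initial length, length = committed bytes)
    println("committed %d bytes: %s".format(committed, segment))
  }
}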
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index fa834371f4..ea42656240 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
- def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
/**
* Return the size of a block in bytes.
*/
- def getSize(blockId: String): Long
+ def getSize(blockId: BlockId): Long
- def getBytes(blockId: String): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[ByteBuffer]
- def getValues(blockId: String): Option[Iterator[Any]]
+ def getValues(blockId: BlockId): Option[Iterator[Any]]
/**
* Remove a block, if it exists.
* @param blockId the block to remove.
* @return True if the block was found and removed, False otherwise.
*/
- def remove(blockId: String): Boolean
+ def remove(blockId: BlockId): Boolean
- def contains(blockId: String): Boolean
+ def contains(blockId: BlockId): Boolean
def clear() { }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..fcd2e97982
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.Utils
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+ extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ addShutdownHook()
+
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+ shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ def getFile(blockId: BlockId): File = getFile(blockId.name)
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
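The placement scheme in getFile is worth spelling out: the filename is hashed once, the hash modulo the number of local dirs picks the spark.local.dir entry, and the quotient modulo spark.diskStore.subDirectories picks the "%02x" subdirectory inside it. A standalone sketch of that arithmetic, with a plain non-negative hashCode standing in for Utils.nonNegativeHash:

object PlacementSketch {
  // Mirrors DiskBlockManager.getFile's two-level placement, written standalone for illustration.
  def placement(filename: String, numLocalDirs: Int, subDirsPerLocalDir: Int): (Int, Int) = {
    val hash = filename.hashCode & Integer.MAX_VALUE           // stand-in for Utils.nonNegativeHash
    val dirId = hash % numLocalDirs                            // which spark.local.dir entry
    val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir  // which "%02x" subdirectory inside it
    (dirId, subDirId)
  }

  def main(args: Array[String]) {
    val (dirId, subDirId) = placement("shuffle_0_1_2", 2, 64)
    println("dir %d, subdir %02x".format(dirId, subDirId))
  }
}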
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 63447baf8c..5a1e7b4444 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,153 +17,46 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
-
- override def open(): DiskBlockObjectWriter = {
- val fos = new FileOutputStream(f, true)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- objOut.close()
- channel = null
- bs = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ override def getSize(blockId: BlockId): Long = {
+ diskManager.getBlockLocation(blockId).length
}
- override def getSize(blockId: String): Long = {
- getFile(blockId).length()
- }
-
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.getFile(blockId)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -171,159 +64,62 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.getFile(blockId)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
* A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code.
*/
- def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}
- override def remove(blockId: String): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ override def remove(blockId: BlockId): Boolean = {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
- false
- }
- }
-
- override def contains(blockId: String): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: String): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
}
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
+ false
}
}
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
- return null
- }
- DiskStore.this.getFile(blockId).getAbsolutePath()
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ override def contains(blockId: BlockId): Boolean = {
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based off an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
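A FileSegment only records where the data lives; reading it back is left to callers. A minimal sketch that memory-maps exactly the described region, mirroring what DiskStore.getBytes above does:

package org.apache.spark.storage

import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

object FileSegmentSketch {
  // Map only the [offset, offset + length) region that the segment describes.
  def readSegment(segment: FileSegment): ByteBuffer = {
    val channel = new RandomAccessFile(segment.file, "r").getChannel()
    try {
      channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
    } finally {
      channel.close()
    }
  }
}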
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 77a39c71ed..05f676c6e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
case class Entry(value: Any, size: Long, deserialized: Boolean)
- private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true)
@volatile private var currentMemory = 0L
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
@@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
def freeMemory: Long = maxMemory - currentMemory
- override def getSize(blockId: String): Long = {
+ override def getSize(blockId: BlockId): Long = {
entries.synchronized {
entries.get(blockId).size
}
}
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String): Boolean = {
+ override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
val entry = entries.remove(blockId)
if (entry != null) {
@@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
/**
- * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
*/
- private def getRddId(blockId: String): String = {
- if (blockId.startsWith("rdd_")) {
- blockId.split('_')(1)
- } else {
- null
- }
+ private def getRddId(blockId: BlockId): Option[Int] = {
+ blockId.asRDDId.map(_.rddId)
}
/**
@@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* blocks to free memory for one block, another thread may use up the freed space for
* another block.
*/
- private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+ private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
// TODO: Its possible to optimize the locking by locking entries only when selecting blocks
// to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been
// released, it must be ensured that those to-be-dropped blocks are not double counted for
@@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value.
*/
- private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
@@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
- val selectedBlocks = new ArrayBuffer[String]()
+ val selectedBlocks = new ArrayBuffer[BlockId]()
var selectedMemory = 0L
// This is synchronized to ensure that the set of entries is not changed
@@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+ if (rddToAdd != None && rddToAdd == getRddId(blockId)) {
logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
"block from the same RDD")
return false
@@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return true
}
- override def contains(blockId: String): Boolean = {
+ override def contains(blockId: BlockId): Boolean = {
entries.synchronized { entries.containsKey(blockId) }
}
}
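The eviction guard in ensureFreeSpace now compares Option[Int] values: getRddId returns Some(rddId) for RDD blocks and None otherwise, so blocks of the RDD currently being inserted are never picked as eviction victims. A small illustration of that comparison using the BlockId.asRDDId accessor seen above:

package org.apache.spark.storage

object EvictionGuardSketch {
  def getRddId(blockId: BlockId): Option[Int] = blockId.asRDDId.map(_.rddId)

  def main(args: Array[String]) {
    val incoming = RDDBlockId(7, 0)
    val candidates = Seq(RDDBlockId(7, 3), RDDBlockId(2, 0), TestBlockId("x"))

    val rddToAdd = getRddId(incoming)   // Some(7)
    // A candidate is protected when it belongs to the same RDD as the incoming block.
    val victims = candidates.filterNot(c => rddToAdd.isDefined && rddToAdd == getRddId(c))
    println(victims)                    // the incoming RDD's own blocks are filtered out
  }
}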
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 9da11efb57..2f1b049ce4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,51 +17,199 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
-private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+ val writers: Array[BlockObjectWriter]
-private[spark]
-trait ShuffleBlocks {
- def acquireWriters(mapId: Int): ShuffleWriterGroup
- def releaseWriters(group: ShuffleWriterGroup)
+ /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+ def releaseWriters(success: Boolean)
}
-
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within the ShuffleFileGroups associated with the block's reducer.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+ private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+ /**
+ * Contains all the state related to a particular shuffle. This includes a pool of unused
+ * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+ */
+ private class ShuffleState() {
+ val nextFileId = new AtomicInteger(0)
+ val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ }
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
- new ShuffleBlocks {
- // Get a group of writers for a map task.
- override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
- val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ type ShuffleId = Int
+ private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+ private
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
+
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ new ShuffleWriterGroup {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ private val shuffleState = shuffleStates(shuffleId)
+ private var fileGroup: ShuffleFileGroup = null
+
+ val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ fileGroup = getUnusedFileGroup()
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ }
+ } else {
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ val blockFile = blockManager.diskBlockManager.getFile(blockId)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ }
+ }
+
+ override def releaseWriters(success: Boolean) {
+ if (consolidateShuffleFiles) {
+ if (success) {
+ val offsets = writers.map(_.fileSegment().offset)
+ fileGroup.recordMapOutput(mapId, offsets)
+ }
+ recycleFileGroup(fileGroup)
}
- new ShuffleWriterGroup(mapId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ private def getUnusedFileGroup(): ShuffleFileGroup = {
+ val fileGroup = shuffleState.unusedFileGroups.poll()
+ if (fileGroup != null) fileGroup else newFileGroup()
+ }
+
+ private def newFileGroup(): ShuffleFileGroup = {
+ val fileId = shuffleState.nextFileId.getAndIncrement()
+ val files = Array.tabulate[File](numBuckets) { bucketId =>
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.diskBlockManager.getFile(filename)
+ }
+ val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files)
+ shuffleState.allFileGroups.add(fileGroup)
+ fileGroup
+ }
+
+ private def recycleFileGroup(group: ShuffleFileGroup) {
+ shuffleState.unusedFileGroups.add(group)
}
}
}
-}
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * This function should only be called if shuffle file consolidation is enabled, as it is
+ * an error condition if we don't find the expected block.
+ */
+ def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(id.shuffleId)
+ for (fileGroup <- shuffleState.allFileGroups) {
+ val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+ if (segment.isDefined) { return segment.get }
+ }
+ throw new IllegalStateException("Failed to find shuffle block: " + id)
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ shuffleStates.clearOldValues(cleanupTime)
+ }
+}
private[spark]
object ShuffleBlockManager {
+ /**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ */
+ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+ /**
+ * Stores the absolute index of each mapId in the files of this group. For instance,
+ * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+ */
+ private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
- // Returns the block id for a given shuffle block.
- def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
- "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
- }
+ /**
+ * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+ * This ordering allows us to compute block lengths by examining the following block offset.
+ * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+ * reducer.
+ */
+ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+ new PrimitiveVector[Long]()
+ }
+
+ def numBlocks = mapIdToIndex.size
+
+ def apply(bucketId: Int) = files(bucketId)
+
+ def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+ mapIdToIndex(mapId) = numBlocks
+ for (i <- 0 until offsets.length) {
+ blockOffsetsByReducer(i) += offsets(i)
+ }
+ }
- // Returns true if the block is a shuffle block.
- def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+ /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+ val file = files(reducerId)
+ val blockOffsets = blockOffsetsByReducer(reducerId)
+ val index = mapIdToIndex.getOrElse(mapId, -1)
+ if (index >= 0) {
+ val offset = blockOffsets(index)
+ val length =
+ if (index + 1 < numBlocks) {
+ blockOffsets(index + 1) - offset
+ } else {
+ file.length() - offset
+ }
+ assert(length >= 0)
+ Some(new FileSegment(file, offset, length))
+ } else {
+ None
+ }
+ }
+ }
}
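To make the consolidation metadata concrete: a ShuffleFileGroup keeps one offset vector per reducer file, a block's length is derived from the next block's offset (or the end of the file for the last block), and lookup goes mapId -> index -> offset. A self-contained sketch of that arithmetic with made-up offsets (the real class also carries the shuffleId, fileId and File handles):

object ShuffleFileGroupSketch {
  def main(args: Array[String]) {
    // Offsets of consecutive map outputs within one reducer file: three map outputs
    // written back-to-back at byte positions 0, 120 and 250 in a 400-byte file.
    val blockOffsets = Vector(0L, 120L, 250L)
    val fileLength = 400L

    // The length of the block at `index` is bounded by the next offset, or by the end of the file.
    def segmentLength(index: Int): Long =
      if (index + 1 < blockOffsets.length) blockOffsets(index + 1) - blockOffsets(index)
      else fileLength - blockOffsets(index)

    // mapIdToIndex records the order in which mappers wrote to this group; mapId 5 wrote first.
    val mapIdToIndex = Map(5 -> 0, 9 -> 1, 3 -> 2)

    val idx = mapIdToIndex(9)
    println("map 9 -> offset %d, length %d".format(blockOffsets(idx), segmentLength(idx)))
    // prints: map 9 -> offset 120, length 130
  }
}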
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
new file mode 100644
index 0000000000..1e4db4f66b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -0,0 +1,86 @@
+package org.apache.spark.storage
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{CountDownLatch, Executors}
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/**
+ * Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput.
+ */
+object StoragePerfTester {
+ def main(args: Array[String]) = {
+ /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
+ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
+
+ /** Number of map tasks. All tasks execute concurrently. */
+ val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
+
+ /** Number of reduce splits for each map task. */
+ val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
+
+ val recordLength = 1000 // ~1KB records
+ val totalRecords = dataSizeMb * 1000
+ val recordsPerMap = totalRecords / numMaps
+
+ val writeData = "1" * recordLength
+ val executor = Executors.newFixedThreadPool(numMaps)
+
+ System.setProperty("spark.shuffle.compress", "false")
+ System.setProperty("spark.shuffle.sync", "true")
+
+ // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
+ val sc = new SparkContext("local[4]", "Write Tester")
+ val blockManager = sc.env.blockManager
+
+ def writeOutputBytes(mapId: Int, total: AtomicLong) = {
+ val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
+ new KryoSerializer())
+ val writers = shuffle.writers
+ for (i <- 1 to recordsPerMap) {
+ writers(i % numOutputSplits).write(writeData)
+ }
+ writers.map {w =>
+ w.commit()
+ total.addAndGet(w.fileSegment().length)
+ w.close()
+ }
+
+ shuffle.releaseWriters(true)
+ }
+
+ val start = System.currentTimeMillis()
+ val latch = new CountDownLatch(numMaps)
+ val totalBytes = new AtomicLong()
+ for (task <- 1 to numMaps) {
+ executor.submit(new Runnable() {
+ override def run() = {
+ try {
+ writeOutputBytes(task, totalBytes)
+ latch.countDown()
+ } catch {
+ case e: Exception =>
+ println("Exception in child thread: " + e + " " + e.getMessage)
+ System.exit(1)
+ }
+ }
+ })
+ }
+ latch.await()
+ val end = System.currentTimeMillis()
+ val time = (end - start) / 1000.0
+ val bytesPerSecond = totalBytes.get() / time
+ val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
+
+ System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
+ System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
+ System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
+
+ executor.shutdown()
+ sc.stop()
+ }
+}
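For reference, the numbers this tool reports reduce to simple arithmetic over the defaults above (1g of ~1KB records, 8 maps, 500 reducers); a quick worked sketch under those assumptions, ignoring serialization overhead:

object PerfTesterMath {
  def main(args: Array[String]) {
    val dataSizeMb = 1024                                     // the "1g" default
    val numMaps = 8
    val numOutputSplits = 500
    val recordLength = 1000                                   // ~1KB records, as above

    val totalRecords = dataSizeMb * 1000                      // 1,024,000 records
    val recordsPerMap = totalRecords / numMaps                // 128,000 records per map task
    val filesTotal = numMaps * numOutputSplits                // 4,000 shuffle files (or file segments)
    val recordsPerFile = recordsPerMap / numOutputSplits      // 256 records per file
    val approxBytesPerFile = recordsPerFile * recordLength    // ~256 KB before serialization overhead

    println("files=%d, records/file=%d, ~bytes/file=%d".format(filesTotal, recordsPerFile, approxBytesPerFile))
  }
}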
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 2bb7715696..1720007e4e 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -23,20 +23,24 @@ import org.apache.spark.util.Utils
private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
- blocks: Map[String, BlockStatus]) {
+ blocks: Map[BlockId, BlockStatus]) {
- def memUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
- def diskUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed()
+ def rddBlocks = blocks.flatMap {
+ case (rdd: RDDBlockId, status) => Some(rdd, status)
+ case _ => None
+ }
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@@ -60,7 +64,7 @@ object StorageUtils {
/* Returns RDD-level information, compiled from a list of StorageStatus objects */
def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
}
/* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@@ -71,26 +75,21 @@ object StorageUtils {
}
/* Given a list of BlockStatus objects, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
- k.substring(0,k.lastIndexOf('_'))
- }.mapValues(_.values.toArray)
+ val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
- // Find the id of the RDD, e.g. rdd_1 => 1
- val rddId = rddKey.split("_").last.toInt
-
// Get the friendly name and storage level for the RDD, if available
sc.persistentRdds.get(rddId).map { r =>
- val rddName = Option(r.name).getOrElse(rddKey)
+ val rddName = Option(r.name).getOrElse(rddId.toString)
val rddStorageLevel = r.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
}
@@ -101,16 +100,14 @@ object StorageUtils {
rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
- prefix: String) : Array[StorageStatus] = {
+ /* Filters storage status by a given RDD id. */
+ def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
+ : Array[StorageStatus] = {
storageStatusList.map { status =>
- val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
}
-
}
-
}
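The per-RDD accounting above hinges on rddBlocks, which keeps only the RDDBlockId-keyed entries of the block map. A small sketch of the same collect-and-filter with made-up sizes:

package org.apache.spark.storage

object RddUsageSketch {
  def main(args: Array[String]) {
    // Hypothetical block -> memory-size map for one executor.
    val memSizes: Map[BlockId, Long] = Map(
      RDDBlockId(1, 0) -> 100L,
      RDDBlockId(1, 1) -> 150L,
      RDDBlockId(2, 0) -> 999L,
      TestBlockId("x") -> 42L)

    // Keep only RDD blocks, as StorageStatus.rddBlocks does, then sum for one RDD.
    val rddBlocks = memSizes.flatMap {
      case (id: RDDBlockId, size) => Some(id -> size)
      case _ => None
    }
    val memUsedByRdd1 = rddBlocks.filterKeys(_.rddId == 1).values.sum
    println(memUsedByRdd1)   // 250
  }
}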
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index f2ae8dd97d..860e680576 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -36,11 +36,11 @@ private[spark] object ThreadingTest {
val numBlocksPerProducer = 20000
private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+ val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
override def run() {
for (i <- 1 to numBlocksPerProducer) {
- val blockId = "b-" + id + "-" + i
+ val blockId = TestBlockId("b-" + id + "-" + i)
val blockSize = Random.nextInt(1000)
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
@@ -64,7 +64,7 @@ private[spark] object ThreadingTest {
private[spark] class ConsumerThread(
manager: BlockManager,
- queue: ArrayBlockingQueue[(String, Seq[Int])]
+ queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
) extends Thread {
var numBlockConsumed = 0
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 453394dfda..fcd1b518d0 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -35,7 +35,7 @@ private[spark] object UIWorkloadGenerator {
def main(args: Array[String]) {
if (args.length < 2) {
- println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
System.exit(1)
}
val master = args(0)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index b39c0e9769..ca5a28625b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
- for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eb3b4e8522..6b854740d6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
- val stageToPool = new HashMap[Stage, String]()
- val stageToDescription = new HashMap[Stage, String]()
- val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+ val stageIdToPool = new HashMap[Int, String]()
+ val stageIdToDescription = new HashMap[Int, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
- val activeStages = HashSet[Stage]()
- val completedStages = ListBuffer[Stage]()
- val failedStages = ListBuffer[Stage]()
+ val activeStages = HashSet[StageInfo]()
+ val completedStages = ListBuffer[StageInfo]()
+ val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
- val stageToTime = HashMap[Int, Long]()
- val stageToShuffleRead = HashMap[Int, Long]()
- val stageToShuffleWrite = HashMap[Int, Long]()
- val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
- val stageToTasksComplete = HashMap[Int, Int]()
- val stageToTasksFailed = HashMap[Int, Int]()
- val stageToTaskInfos =
+ val stageIdToTime = HashMap[Int, Long]()
+ val stageIdToShuffleRead = HashMap[Int, Long]()
+ val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageIdToTasksComplete = HashMap[Int, Int]()
+ val stageIdToTasksFailed = HashMap[Int, Int]()
+ val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
- val stage = stageCompleted.stageInfo.stage
- poolToActiveStages(stageToPool(stage)) -= stage
+ val stage = stageCompleted.stage
+ poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
- def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
- stageToTaskInfos.remove(s.id)
- stageToTime.remove(s.id)
- stageToShuffleRead.remove(s.id)
- stageToShuffleWrite.remove(s.id)
- stageToTasksActive.remove(s.id)
- stageToTasksComplete.remove(s.id)
- stageToTasksFailed.remove(s.id)
- stageToPool.remove(s)
- if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ stageIdToTaskInfos.remove(s.stageId)
+ stageIdToTime.remove(s.stageId)
+ stageIdToShuffleRead.remove(s.stageId)
+ stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToTasksActive.remove(s.stageId)
+ stageIdToTasksComplete.remove(s.stageId)
+ stageIdToTasksFailed.remove(s.stageId)
+ stageIdToPool.remove(s.stageId)
+ if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- stageToPool(stage) = poolName
+ stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- description.map(d => stageToDescription(stage) = d)
+ description.map(d => stageIdToDescription(stage.stageId) = d)
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
-
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+ = synchronized {
+ // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+ // stageIdToTaskInfos already has the updated status.
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
- stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
- stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
- stageToTime.getOrElseUpdate(sid, 0L)
+ stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
- stageToTime(sid) += time
+ stageIdToTime(sid) += time
totalTime += time
- stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
- stageToShuffleRead(sid) += shuffleRead
+ stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
- stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
- stageToShuffleWrite(sid) += shuffleWrite
+ stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
- activeStages -= stage
- poolToActiveStages(stageToPool(stage)) -= stage
- failedStages += stage
- trimIfNecessary(failedStages)
+ /* If two jobs share a stage we could get this failure message twice. So we first
+ * check whether we've already retired this stage. */
+ val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+ stageInfo.foreach {s =>
+ activeStages -= s
+ poolToActiveStages(stageIdToPool(stage.id)) -= s
+ failedStages += s
+ trimIfNecessary(failedStages)
+ }
case _ =>
}
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
- var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+ var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
- private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
- private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..35b5d5fd59 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
- if (!listener.stageToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
- val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
- val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
- val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
- listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
- {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
@@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
</div>
val taskHeaders: Seq[String] =
- Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
- Seq("GC Time") ++
+ Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
+ Seq("Duration", "GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
- {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
<tr>
+ <td>{info.index}</td>
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
@@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
}}
{if (shuffleWrite) {
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 07db8622da..d7d0441c38 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
+import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
-private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
- private def stageRow(s: Stage): Seq[Node] = {
+ private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
- val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+ val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
- val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
- val totalTasks = s.numPartitions
+ val totalTasks = s.numTasks
- val poolName = listener.stageToPool.get(s)
+ val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
- <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
- val description = listener.stageToDescription.get(s)
+ <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a>
+ val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
- <td>{s.id}</td>
+ <td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 43c1257677..b83cd54f3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.storage.{StorageStatus, StorageUtils}
+import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils}
import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._
@@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
val sc = parent.sc
def render(request: HttpServletRequest): Seq[Node] = {
- val id = request.getParameter("id")
- val prefix = "rdd_" + id.toString
+ val id = request.getParameter("id").toInt
val storageStatusList = sc.getExecutorStorageStatus
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
+ val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
- val workers = filteredStorageStatusList.map((prefix, _))
+ val workers = filteredStorageStatusList.map((id, _))
val workerTable = listingTable(workerHeaders, workerRow, workers)
val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
"Executors")
- val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1)
+ val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.
+ sortWith(_._1.name < _._1.name)
val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
val blocks = blockStatuses.map {
case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
@@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
}
- def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = {
+ def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = {
val (id, block, locations) = row
<tr>
<td>{id}</td>
@@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
</tr>
}
- def workerRow(worker: (String, StorageStatus)): Seq[Node] = {
- val (prefix, status) = worker
+ def workerRow(worker: (Int, StorageStatus)): Seq[Node] = {
+ val (rddId, status) = worker
<tr>
<td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
<td>
- {Utils.bytesToString(status.memUsed(prefix))}
+ {Utils.bytesToString(status.memUsedByRDD(rddId))}
({Utils.bytesToString(status.memRemaining)} Remaining)
</td>
- <td>{Utils.bytesToString(status.diskUsed(prefix))}</td>
+ <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
new file mode 100644
index 0000000000..f60deafc6f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A simple open hash table optimized for the append-only use case, where keys
+ * are never removed, but the value for each key may be changed.
+ *
+ * This implementation uses quadratic probing with a power-of-2 hash table
+ * size, which is guaranteed to explore all spaces for each key (see
+ * http://en.wikipedia.org/wiki/Quadratic_probing).
+ *
+ * TODO: Cache the hash values of each key? java.util.HashMap does that.
+ */
+private[spark]
+class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ private var capacity = nextPowerOf2(initialCapacity)
+ private var mask = capacity - 1
+ private var curSize = 0
+
+ // Holds keys and values in the same array for memory locality; specifically, the order of
+ // elements is key0, value0, key1, value1, key2, value2, etc.
+ private var data = new Array[AnyRef](2 * capacity)
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ private val LOAD_FACTOR = 0.7
+
+ /** Get the value for a given key */
+ def apply(key: K): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ return data(2 * pos + 1).asInstanceOf[V]
+ } else if (curKey.eq(null)) {
+ return null.asInstanceOf[V]
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return null.asInstanceOf[V]
+ }
+
+ /** Set the value for a key */
+ def update(key: K, value: V): Unit = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = value
+ haveNullValue = true
+ return
+ }
+ val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
+ if (isNewEntry) {
+ incrementSize()
+ }
+ }
+
+ /**
+ * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
+ * for key, if any, or null otherwise. Returns the newly updated value.
+ */
+ def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+ val k = key.asInstanceOf[AnyRef]
+ if (k.eq(null)) {
+ if (!haveNullValue) {
+ incrementSize()
+ }
+ nullValue = updateFunc(haveNullValue, nullValue)
+ haveNullValue = true
+ return nullValue
+ }
+ var pos = rehash(k.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (k.eq(curKey) || k == curKey) {
+ val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ return newValue
+ } else if (curKey.eq(null)) {
+ val newValue = updateFunc(false, null.asInstanceOf[V])
+ data(2 * pos) = k
+ data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+ incrementSize()
+ return newValue
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ null.asInstanceOf[V] // Never reached but needed to keep compiler happy
+ }
+
+ /** Iterator method from Iterable */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = -1
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def nextValue(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ }
+ pos += 1
+ }
+ null
+ }
+
+ override def hasNext: Boolean = nextValue() != null
+
+ override def next(): (K, V) = {
+ val value = nextValue()
+ if (value == null) {
+ throw new NoSuchElementException("End of iterator")
+ }
+ pos += 1
+ value
+ }
+ }
+
+ override def size: Int = curSize
+
+ /** Increase table size by 1, rehashing if necessary */
+ private def incrementSize() {
+ curSize += 1
+ if (curSize > LOAD_FACTOR * capacity) {
+ growTable()
+ }
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def rehash(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ /**
+ * Put an entry into a table represented by data, returning true if
+ * this increases the size of the table or false otherwise. Assumes
+ * that "data" has at least one empty slot.
+ */
+ private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
+ val mask = (data.length / 2) - 1
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = key
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return true
+ } else if (curKey.eq(key) || curKey == key) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return false
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ return false // Never reached but needed to keep compiler happy
+ }
+
+ /** Double the table's size and re-hash everything */
+ private def growTable() {
+ val newCapacity = capacity * 2
+ if (newCapacity >= (1 << 30)) {
+ // We can't make the table this big because we want an array of 2x
+ // that size for our data, but array sizes are at most Int.MaxValue
+ throw new Exception("Can't make capacity bigger than 2^29 elements")
+ }
+ val newData = new Array[AnyRef](2 * newCapacity)
+ var pos = 0
+ while (pos < capacity) {
+ if (!data(2 * pos).eq(null)) {
+ putInto(newData, data(2 * pos), data(2 * pos + 1))
+ }
+ pos += 1
+ }
+ data = newData
+ capacity = newCapacity
+ mask = newCapacity - 1
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
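A minimal usage sketch (not part of this patch) of the new AppendOnlyMap: counting key occurrences with changeValue. The object name below is hypothetical, and the sketch assumes it is compiled under org.apache.spark since the class is private[spark].

package org.apache.spark

import org.apache.spark.util.AppendOnlyMap

object AppendOnlyMapExample {
  def main(args: Array[String]) {
    val counts = new AppendOnlyMap[String, Int]()
    for (word <- Seq("a", "b", "a", "c", "a")) {
      // hadValue is false the first time a key is seen, so start the count at 1.
      counts.changeValue(word, (hadValue, oldValue) => if (hadValue) oldValue + 1 else 1)
    }
    counts.foreach { case (k, v) => println(k + " -> " + v) }  // a -> 3, b -> 1, c -> 1
  }
}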
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index a430a75451..67a7f87a5c 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -17,7 +17,6 @@
package org.apache.spark.util
-import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
import java.util.{TimerTask, Timer}
import org.apache.spark.Logging
@@ -25,11 +24,14 @@ import org.apache.spark.Logging
/**
* Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
*/
-class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging {
+ val name = cleanerType.toString
+
private val delaySeconds = MetadataCleaner.getDelaySeconds
private val periodSeconds = math.max(10, delaySeconds / 10)
private val timer = new Timer(name + " cleanup timer", true)
+
private val task = new TimerTask {
override def run() {
try {
@@ -53,9 +55,38 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
}
}
+object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
+
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
+
+ type MetadataCleanerType = Value
+
+ def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString
+}
object MetadataCleaner {
+
+  // Using only system properties for now, so that workers can also read the TTL while preserving earlier behavior.
def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
- def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) }
+
+ def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+ System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt
+ }
+
+ def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) {
+ System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
+ }
+
+ def setDelaySeconds(delay: Int, resetAll: Boolean = true) {
+    // Set the global TTL; optionally clear per-type overrides so the new value applies to all cleaners.
+ System.setProperty("spark.cleaner.ttl", delay.toString)
+ if (resetAll) {
+ for (cleanerType <- MetadataCleanerType.values) {
+ System.clearProperty(MetadataCleanerType.systemProperty(cleanerType))
+ }
+ }
+ }
}
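A minimal sketch (not part of this patch) of how the new per-cleaner TTLs resolve: a type-specific property (here spark.cleaner.ttl.BlockManager, derived via MetadataCleanerType.systemProperty) overrides the global spark.cleaner.ttl, which remains the fallback. The object name is hypothetical.

package org.apache.spark

import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType}

object CleanerTtlExample {
  def main(args: Array[String]) {
    System.setProperty("spark.cleaner.ttl", "3600")                             // global default
    MetadataCleaner.setDelaySeconds(MetadataCleanerType.BLOCK_MANAGER, 600)     // per-type override
    println(MetadataCleaner.getDelaySeconds(MetadataCleanerType.BLOCK_MANAGER)) // 600
    println(MetadataCleaner.getDelaySeconds(MetadataCleanerType.SPARK_CONTEXT)) // 3600 (fallback)
  }
}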
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 94ce50e964..7557ddab19 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,16 +18,12 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
-import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
-import java.util.regex.Pattern
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.io.Source
@@ -43,7 +39,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkEnv, SparkException, Logging}
+import org.apache.spark.{SparkException, Logging}
/**
@@ -155,7 +151,7 @@ private[spark] object Utils extends Logging {
return buf
}
- private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -287,9 +283,8 @@ private[spark] object Utils extends Logging {
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val env = SparkEnv.get
val uri = new URI(url)
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -454,14 +449,17 @@ private[spark] object Utils extends Logging {
hostPortParseResults.get(hostPort)
}
- private[spark] val daemonThreadFactory: ThreadFactory =
- new ThreadFactoryBuilder().setDaemon(true).build()
+ private val daemonThreadFactoryBuilder: ThreadFactoryBuilder =
+ new ThreadFactoryBuilder().setDaemon(true)
/**
- * Wrapper over newCachedThreadPool.
+ * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor =
- Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -472,10 +470,13 @@ private[spark] object Utils extends Logging {
}
/**
- * Wrapper over newFixedThreadPool.
+ * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
- Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private def listFilesSafely(file: File): Seq[File] = {
val files = file.listFiles()
@@ -820,4 +821,10 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+  /** Returns a copy of the system properties that is thread-safe to iterate over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
}
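A minimal sketch (not part of this patch) of the new thread-pool naming: passing a prefix makes the daemon threads appear as prefix-0, prefix-1, ... in thread dumps. The object name and the "reducer-fetch" prefix are made up for illustration; it assumes compilation under org.apache.spark since Utils is private[spark].

package org.apache.spark

import org.apache.spark.util.Utils

object NamedThreadPoolExample {
  def main(args: Array[String]) {
    val pool = Utils.newDaemonFixedThreadPool(2, "reducer-fetch")
    (1 to 4).foreach { i =>
      pool.execute(new Runnable {
        def run() { println(Thread.currentThread.getName + " ran task " + i) }
      })
    }
    pool.shutdown()
  }
}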
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+ private[this] val words = new Array[Long](bit2words(numBits))
+ private[this] val numWords = words.length
+
+ /**
+ * Sets the bit at the specified index to true.
+ * @param index the bit index
+ */
+ def set(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) |= bitmask // div by 64 and mask
+ }
+
+ /**
+ * Return the value of the bit with the specified index. The value is true if the bit with
+ * the index is currently set in this BitSet; otherwise, the result is false.
+ *
+ * @param index the bit index
+ * @return the value of the bit with the specified index
+ */
+ def get(index: Int): Boolean = {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ (words(index >> 6) & bitmask) != 0 // div by 64 and mask
+ }
+
+ /** Return the number of bits set to true in this BitSet. */
+ def cardinality(): Int = {
+ var sum = 0
+ var i = 0
+ while (i < numWords) {
+ sum += java.lang.Long.bitCount(words(i))
+ i += 1
+ }
+ sum
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then -1 is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+ * // operate on index i here
+ * }
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ def nextSetBit(fromIndex: Int): Int = {
+ var wordIndex = fromIndex >> 6
+ if (wordIndex >= numWords) {
+ return -1
+ }
+
+ // Try to find the next set bit in the current word
+ val subIndex = fromIndex & 0x3f
+ var word = words(wordIndex) >> subIndex
+ if (word != 0) {
+ return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
+ }
+
+ // Find the next set bit in the rest of the words
+ wordIndex += 1
+ while (wordIndex < numWords) {
+ word = words(wordIndex)
+ if (word != 0) {
+ return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
+ }
+ wordIndex += 1
+ }
+
+ -1
+ }
+
+ /** Return the number of longs it would take to hold numBits. */
+ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+}
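A minimal sketch (not part of this patch) of iterating a BitSet with nextSetBit, mirroring the loop from the scaladoc above. The object name is hypothetical.

package org.apache.spark

import org.apache.spark.util.collection.BitSet

object BitSetExample {
  def main(args: Array[String]) {
    val bits = new BitSet(128)
    Seq(3, 64, 100).foreach(bits.set)
    var i = bits.nextSetBit(0)
    while (i >= 0) {
      println("bit " + i + " is set")
      i = bits.nextSetBit(i + 1)
    }
    println("cardinality = " + bits.cardinality())  // 3
  }
}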
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
new file mode 100644
index 0000000000..45849b3380
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+
+/**
+ * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
+ * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
+ * space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ private var _values: Array[V] = _
+ _values = new Array[V](_keySet.capacity)
+
+ @transient private var _oldValues: Array[V] = null
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ if (k == null) {
+ nullValue
+ } else {
+ val pos = _keySet.getPos(k)
+ if (pos < 0) {
+ null.asInstanceOf[V]
+ } else {
+ _values(pos)
+ }
+ }
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ if (k == null) {
+ haveNullValue = true
+ nullValue = v
+ } else {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ if (k == null) {
+ if (haveNullValue) {
+ nullValue = mergeValue(nullValue)
+ } else {
+ haveNullValue = true
+ nullValue = defaultValue
+ }
+ nullValue
+ } else {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = -1
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ pos += 1
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
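A minimal sketch (not part of this patch) of OpenHashMap.changeValue for building a frequency map over nullable keys. The object name is hypothetical; it assumes compilation under org.apache.spark since the class is private[spark].

package org.apache.spark

import org.apache.spark.util.collection.OpenHashMap

object OpenHashMapExample {
  def main(args: Array[String]) {
    val counts = new OpenHashMap[String, Int]()
    for (word <- Seq("x", "y", "x", null)) {
      // defaultValue is used the first time a key (including null) is seen, mergeValue afterwards.
      counts.changeValue(word, 1, _ + 1)
    }
    // Prints x -> 2, y -> 1 and null -> 1 (iteration order not guaranteed).
    counts.foreach { case (k, v) => println(k + " -> " + v) }
  }
}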
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
new file mode 100644
index 0000000000..49d95afdb9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
+ * removed.
+ *
+ * The underlying implementation uses Scala compiler's specialization to generate optimized
+ * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
+ * while incurring much less memory overhead. This can serve as building blocks for higher level
+ * data structures such as an optimized HashMap.
+ *
+ * This OpenHashSet is designed to serve as a building block for higher level data structures
+ * such as an optimized hash map. Compared with standard hash set implementations, this class
+ * provides various callback interfaces (e.g. allocateFunc, moveFunc) and interfaces to
+ * retrieve the position of a key in the underlying array.
+ *
+ * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
+ * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
+ */
+private[spark]
+class OpenHashSet[@specialized(Long, Int) T: ClassTag](
+ initialCapacity: Int,
+ loadFactor: Double)
+ extends Serializable {
+
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(loadFactor < 1.0, "Load factor must be less than 1.0")
+ require(loadFactor > 0.0, "Load factor must be greater than 0.0")
+
+ import OpenHashSet._
+
+ def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+ def this() = this(64)
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+
+ protected val hasher: Hasher[T] = {
+ // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
+ // compiler has a bug when specialization is used together with this pattern matching, and
+ // throws:
+ // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+ // found : scala.reflect.AnyValManifest[Long]
+ // required: scala.reflect.ClassTag[Int]
+ // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+ // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+ // ...
+ val mt = classTag[T]
+ if (mt == ClassTag.Long) {
+ (new LongHasher).asInstanceOf[Hasher[T]]
+ } else if (mt == ClassTag.Int) {
+ (new IntHasher).asInstanceOf[Hasher[T]]
+ } else {
+ new Hasher[T]
+ }
+ }
+
+ protected var _capacity = nextPowerOf2(initialCapacity)
+ protected var _mask = _capacity - 1
+ protected var _size = 0
+
+ protected var _bitset = new BitSet(_capacity)
+
+ // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+ // specialization bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _data: Array[T] = _
+ _data = new Array[T](_capacity)
+
+ /** Number of elements in the set. */
+ def size: Int = _size
+
+ /** The capacity of the set (i.e. size of the underlying array). */
+ def capacity: Int = _capacity
+
+ /** Return true if this set contains the specified element. */
+ def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+ /**
+ * Add an element to the set. If the set is over capacity after the insertion, grow the set
+ * and rehash all elements.
+ */
+ def add(k: T) {
+ addWithoutResize(k)
+ rehashIfNeeded(k, grow, move)
+ }
+
+ /**
+ * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+ * The caller is responsible for calling rehashIfNeeded.
+ *
+ * Use (retval & POSITION_MASK) to get the actual position, and
+ * (retval & NONEXISTENCE_MASK) != 0 to check whether the key was newly inserted.
+ *
+ * @return The position where the key is placed, plus the highest order bit set if the key
+ * did not exist previously.
+ */
+ def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+
+ /**
+ * Rehash the set if it is overloaded.
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ if (_size > loadFactor * _capacity) {
+ rehash(k, allocateFunc, moveFunc)
+ }
+ }
+
+ /**
+ * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+ */
+ def getPos(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ return INVALID_POS
+ } else if (k == _data(pos)) {
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ INVALID_POS
+ }
+
+ /** Return the value at the specified position. */
+ def getValue(pos: Int): T = _data(pos)
+
+ /**
+ * Return the next position with an element stored, starting from the given position inclusively.
+ */
+ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+ /**
+ * Put an entry into the set. Return the position where the key is placed. In addition, the
+ * highest bit in the returned position is set if the key exists prior to this put.
+ *
+ * This function assumes the data array has at least one empty slot.
+ */
+ private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+ val mask = data.length - 1
+ var pos = hashcode(hasher.hash(k)) & mask
+ var i = 1
+ while (true) {
+ if (!bitset.get(pos)) {
+ // This is a new key.
+ data(pos) = k
+ bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
+
+ /**
+ * Double the table's size and re-hash everything. We are not really using k, but it is declared
+ * so Scala compiler can specialize this method (which leads to calling the specialized version
+ * of putInto).
+ *
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ val newCapacity = _capacity * 2
+ require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+ allocateFunc(newCapacity)
+ val newData = new Array[T](newCapacity)
+ val newBitset = new BitSet(newCapacity)
+ var pos = 0
+ _size = 0
+ while (pos < _capacity) {
+ if (_bitset.get(pos)) {
+ val newPos = putInto(newBitset, newData, _data(pos))
+ moveFunc(pos, newPos & POSITION_MASK)
+ }
+ pos += 1
+ }
+ _bitset = newBitset
+ _data = newData
+ _capacity = newCapacity
+ _mask = newCapacity - 1
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ
+ * in the lower bits, similar to java.util.HashMap
+ */
+ private def hashcode(h: Int): Int = {
+ val r = h ^ (h >>> 20) ^ (h >>> 12)
+ r ^ (r >>> 7) ^ (r >>> 4)
+ }
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+ val INVALID_POS = -1
+ val NONEXISTENCE_MASK = 0x80000000
+  val POSITION_MASK = 0x7FFFFFFF  // complement of NONEXISTENCE_MASK, keeps only the position bits
+
+ /**
+ * A set of specialized hash function implementation to avoid boxing hash code computation
+ * in the specialized implementation of OpenHashSet.
+ */
+ sealed class Hasher[@specialized(Long, Int) T] {
+ def hash(o: T): Int = o.hashCode()
+ }
+
+ class LongHasher extends Hasher[Long] {
+ override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+ }
+
+ class IntHasher extends Hasher[Int] {
+ override def hash(o: Int): Int = o
+ }
+
+ private def grow1(newSize: Int) {}
+ private def move1(oldPos: Int, newPos: Int) { }
+
+ private val grow = grow1 _
+ private val move = move1 _
+}
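A minimal sketch (not part of this patch) of the Int-specialized OpenHashSet interface: add, contains, and getPos / INVALID_POS. The object name is hypothetical; it assumes compilation under org.apache.spark since the class is private[spark].

package org.apache.spark

import org.apache.spark.util.collection.OpenHashSet

object OpenHashSetExample {
  def main(args: Array[String]) {
    val set = new OpenHashSet[Int](64)
    Seq(1, 5, 9, 5).foreach(set.add)
    println(set.size)                                   // 3, duplicates are ignored
    println(set.contains(5))                            // true
    println(set.getPos(7) == OpenHashSet.INVALID_POS)   // true, 7 was never added
  }
}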
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..2e1ef06cbc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, while using much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _keySet: OpenHashSet[K] = _
+ private var _values: Array[V] = _
+ _keySet = new OpenHashSet[K](initialCapacity)
+ _values = new Array[V](_keySet.capacity)
+
+ private var _oldValues: Array[V] = null
+
+ override def size = _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = _keySet.getPos(k)
+ _values(pos)
+ }
+
+ /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = _keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
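A minimal sketch (not part of this patch) of PrimitiveKeyOpenHashMap with Long keys, using update, changeValue and getOrElse. The object name and values are made up; it assumes compilation under org.apache.spark since the class is private[spark].

package org.apache.spark

import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

object PrimitiveKeyMapExample {
  def main(args: Array[String]) {
    val bytesWritten = new PrimitiveKeyOpenHashMap[Long, Long]()
    bytesWritten.update(12L, 1024L)
    bytesWritten.changeValue(12L, 1024L, _ + 1024L)  // key exists, so merge: 2048
    println(bytesWritten.getOrElse(12L, 0L))         // 2048
    println(bytesWritten.getOrElse(99L, 0L))         // 0, absent key
  }
}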
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..465c221d5f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
+ private var numElements = 0
+ private var array: Array[V] = _
+
+ // NB: This must be separate from the declaration, otherwise the specialized parent class
+ // will get its own array with the same initial size. TODO: Figure out why...
+ array = new Array[V](initialSize)
+
+ def apply(index: Int): V = {
+ require(index < numElements)
+ array(index)
+ }
+
+ def +=(value: V) {
+ if (numElements == array.length) { resize(array.length * 2) }
+ array(numElements) = value
+ numElements += 1
+ }
+
+ def length = numElements
+
+ def getUnderlyingArray = array
+
+ /** Resizes the array, dropping elements if the total length decreases. */
+ def resize(newLength: Int) {
+ val newArray = new Array[V](newLength)
+ array.copyToArray(newArray)
+ array = newArray
+ }
+}
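A minimal sketch (not part of this patch) of PrimitiveVector: appending doubles without boxing and reading them back by index. The object name is hypothetical; it assumes compilation under org.apache.spark since the class is private[spark].

package org.apache.spark

import org.apache.spark.util.collection.PrimitiveVector

object PrimitiveVectorExample {
  def main(args: Array[String]) {
    val vec = new PrimitiveVector[Double](initialSize = 4)
    (1 to 10).foreach(i => vec += i * 0.5)  // grows past the initial size automatically
    println(vec.length)                     // 10
    println(vec(9))                         // 5.0
  }
}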
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index b3a53d928b..e022accee6 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,8 +20,42 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- test("basic broadcast") {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.broadcast.factory")
+ }
+
+ test("Using HttpBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
- test("broadcast variables accessed in multiple threads") {
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 3a7171c488..ea936e815b 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel}
// TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
@@ -52,13 +52,14 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(None)
- blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
- andReturn(0)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
+ blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY,
+ true).andReturn(0)
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -66,11 +67,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get cached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+ blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -79,11 +81,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached local rdd") {
expecting {
// Local computation should not persist the resulting value, so don't expect a put().
- blockManager.get("rdd_0_0").andReturn(None)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+ val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 78e35732d2..fcfc2c9893 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
@@ -63,8 +63,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
- (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(r => new MapPartitionsWithContextRDD(r,
+ (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
@@ -84,7 +84,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("BlockRDD") {
- val blockId = "id"
+ val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
@@ -192,7 +192,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("CheckpointRDD with zero partitions") {
- val rdd = new BlockRDD[Int](sc, Array[String]())
+ val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
rdd.checkpoint()
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 988ab1747d..d9cffb74de 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -18,24 +18,14 @@
package org.apache.spark
import network.ConnectionManagerId
-import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Timeouts._
+import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.prop.Checkers
import org.scalatest.time.{Span, Millis}
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
-import org.eclipse.jetty.server.{Server, Request, Handler}
-
-import com.google.common.io.Files
-
-import scala.collection.mutable.ArrayBuffer
import SparkContext._
-import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-import ui.JettyUtils
+import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
class NotSerializableClass
@@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
- val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray
+ val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 35d1d41af1..c210dd5c3b 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -120,4 +120,20 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
+
+ test ("Dynamically adding JARS on a standalone cluster using local: URL") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile.replace("file", "local"))
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader()
+ .loadClass("org.uncommons.maths.Maths")
+ .getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 591c1d498d..352036f182 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -473,6 +473,27 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void repartition() {
+    // Growing number of partitions
+ JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+ JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+ List<List<Integer>> result1 = repartitioned1.glom().collect();
+ Assert.assertEquals(4, result1.size());
+ for (List<Integer> l: result1) {
+ Assert.assertTrue(l.size() > 0);
+ }
+
+    // Shrinking number of partitions
+ JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+ JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+ List<List<Integer>> result2 = repartitioned2.glom().collect();
+ Assert.assertEquals(2, result2.size());
+ for (List<Integer> l: result2) {
+ Assert.assertTrue(l.size() > 0);
+ }
+ }
+
+ @Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
@@ -495,7 +516,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, null);
+ TaskContext context = new TaskContext(0, 0, 0, false, false, null);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
new file mode 100644
index 0000000000..d8a0e983b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener}
+
+
+/**
+ * Test suite for cancelling running jobs. We cover cancellation of single-job actions
+ * (e.g. count) as well as multi-job actions (e.g. take). We test the local and cluster schedulers
+ * in both FIFO and fair scheduling modes.
+ */
+class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
+
+ override def afterEach() {
+ super.afterEach()
+ resetSparkContext()
+ System.clearProperty("spark.scheduler.mode")
+ }
+
+ test("local mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("local mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("job group") {
+ sc = new SparkContext("local[2]", "test")
+
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ // jobA is the one to be cancelled.
+ val jobA = future {
+ sc.setJobGroup("jobA", "this is a job to be cancelled")
+ sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ }
+
+ sc.clearJobGroup()
+ val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+ // Block until both tasks of job A have started and cancel job A.
+ sem.acquire(2)
+ sc.cancelJobGroup("jobA")
+ val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+ assert(e.getMessage contains "cancel")
+
+ // Once A is cancelled, job B should finish fairly quickly.
+ assert(jobB.get() === 100)
+ }
+
+ test("two jobs sharing the same stage") {
+ // sem1: make sure cancel is issued after some tasks are launched
+ // sem2: make sure the first stage is not finished until cancel is issued
+ val sem1 = new Semaphore(0)
+ val sem2 = new Semaphore(0)
+
+ sc = new SparkContext("local[2]", "test")
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem1.release()
+ }
+ })
+
+    // Create two actions that would share the same stages.
+ val rdd = sc.parallelize(1 to 10, 2).map { i =>
+ sem2.acquire()
+ (i, i)
+ }.reduceByKey(_+_)
+ val f1 = rdd.collectAsync()
+ val f2 = rdd.countAsync()
+
+    // Kill one of the actions.
+ future {
+ sem1.acquire()
+ f1.cancel()
+ sem2.release(10)
+ }
+
+ // Expect both to fail now.
+ // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ intercept[SparkException] { f1.get() }
+ intercept[SparkException] { f2.get() }
+ }
+
+ def testCount() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future {
+        // Wait until some tasks have been launched before we cancel the job.
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+
+ def testTake() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future {
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+}
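
A condensed sketch of the cancellation pattern the testCount helper above relies on, assuming a live SparkContext named sc and the countAsync/cancel API exercised by this suite:

    import org.apache.spark.{SparkContext, SparkException}
    import org.apache.spark.SparkContext._

    object CancellationSketch {
      // Sketch only: `sc` is assumed to be a live SparkContext, as in the suite above.
      def cancelCount(sc: SparkContext) {
        val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
        f.cancel()  // ask the scheduler to kill the job backing this FutureAction
        try {
          f.get()   // blocks; throws once the job is reported cancelled/killed
        } catch {
          case e: SparkException =>
            assert(e.getMessage.contains("cancel") || e.getMessage.contains("killed"))
        }
      }
    }
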
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 18fb1bf590..fd174600c7 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
- // As if we had two simulatenous fetch failures
+ // As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- val masterTracker = new MapOutputTracker()
+ val masterTracker = new MapOutputTrackerMaster()
masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
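
The suite above now builds the master-side tracker explicitly. A minimal sketch of that setup, mirroring the test; since MapOutputTrackerMaster and its actor appear to be visible only inside the spark package, the sketch lives in org.apache.spark like the suite itself:

    package org.apache.spark

    import akka.actor.{ActorSystem, Props}

    object TrackerSetupSketch {
      def main(args: Array[String]) {
        val actorSystem = ActorSystem("sketch")
        val tracker = new MapOutputTrackerMaster()  // master-side registry of map outputs
        tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
        tracker.registerShuffle(10, 2)              // shuffle id 10 with two map outputs
        tracker.stop()
        actorSystem.shutdown()
      }
    }
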
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 05f8545c7b..0b38e239f9 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -25,7 +25,7 @@ import net.liftweb.json.JsonAST.JValue
import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
-import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
+import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState, WorkerInfo}
import org.apache.spark.deploy.worker.ExecutorRunner
class JsonProtocolSuite extends FunSuite {
@@ -53,7 +53,8 @@ class JsonProtocolSuite extends FunSuite {
val workers = Array[WorkerInfo](createWorkerInfo(), createWorkerInfo())
val activeApps = Array[ApplicationInfo](createAppInfo())
val completedApps = Array[ApplicationInfo]()
- val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps)
+ val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps,
+ RecoveryState.ALIVE)
val output = JsonProtocol.writeMasterState(stateResponse)
assertValidJson(output)
}
@@ -79,7 +80,7 @@ class JsonProtocolSuite extends FunSuite {
}
def createExecutorRunner() : ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
- new File("sparkHome"), new File("workDir"))
+ new File("sparkHome"), new File("workDir"), ExecutorState.RUNNING)
}
def assertValidJson(json: JValue) {
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
new file mode 100644
index 0000000000..d433806987
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -0,0 +1,20 @@
+package org.apache.spark.deploy.worker
+
+import java.io.File
+import org.scalatest.FunSuite
+import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription}
+
+class ExecutorRunnerTest extends FunSuite {
+
+ test("command includes appId") {
+ def f(s:String) = new File(s)
+ val sparkHome = sys.props("user.dir")
+ val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()),
+ sparkHome, "appUiUrl")
+ val appId = "12345-worker321-9876"
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
+ f("ooga"), ExecutorState.RUNNING)
+
+ assert(er.buildCommandSeq().last === appId)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
new file mode 100644
index 0000000000..da032b17d9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
+
+
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local[2]", "test")
+ }
+
+ override def afterAll() {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+
+ lazy val zeroPartRdd = new EmptyRDD[Int](sc)
+
+ test("countAsync") {
+ assert(zeroPartRdd.countAsync().get() === 0)
+ assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+ }
+
+ test("collectAsync") {
+ assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+
+ val collected = sc.parallelize(1 to 1000, 3).collectAsync().get()
+ assert(collected === (1 to 1000))
+ }
+
+ test("foreachAsync") {
+ zeroPartRdd.foreachAsync(i => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+ accum += 1
+ }.get()
+ assert(accum.value === 1000)
+ }
+
+ test("foreachPartitionAsync") {
+ zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+ accum += 1
+ }.get()
+ assert(accum.value === 9)
+ }
+
+ test("takeAsync") {
+ def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+ val expected = input.take(num)
+ val saw = rdd.takeAsync(num).get()
+ assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
+ .format(rdd.partitions.size, expected, saw))
+ }
+ val input = Range(1, 1000)
+
+ var rdd = sc.parallelize(input, 1)
+ for (num <- Seq(0, 1, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 2)
+ for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 100)
+ for (num <- Seq(0, 1, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 1000)
+ for (num <- Seq(0, 1, 3, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a successful job execution.
+ */
+ test("async success handling") {
+ val f = sc.parallelize(1 to 10, 2).countAsync()
+
+ // Use a semaphore to make sure onSuccess and onComplete's success path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ sem.release()
+ case scala.util.Failure(e) =>
+ info("Should not have reached this code path (onComplete matching Failure)")
+ throw new Exception("Task should succeed")
+ }
+ f.onSuccess { case a: Any =>
+ sem.release()
+ }
+ f.onFailure { case t =>
+ info("Should not have reached this code path (onFailure)")
+ throw new Exception("Task should succeed")
+ }
+ assert(f.get() === 10)
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a failed job execution.
+ */
+ test("async failure handling") {
+ val f = sc.parallelize(1 to 10, 2).map { i =>
+ throw new Exception("intentional"); i
+ }.countAsync()
+
+ // Use a semaphore to make sure onFailure and onComplete's failure path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ info("Should not have reached this code path (onComplete matching Success)")
+ throw new Exception("Task should fail")
+ case scala.util.Failure(e) =>
+ sem.release()
+ }
+ f.onSuccess { case a: Any =>
+ info("Should not have reached this code path (onSuccess)")
+ throw new Exception("Task should fail")
+ }
+ f.onFailure { case t =>
+ sem.release()
+ }
+ intercept[SparkException] {
+ f.get()
+ }
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+}
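
A small sketch of the Future-style callback usage the success/failure tests above depend on, assuming a live SparkContext named sc:

    import java.util.concurrent.Semaphore

    import scala.concurrent.ExecutionContext.Implicits.global

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    object FutureActionCallbackSketch {
      // Sketch only: `sc` is assumed to be a live SparkContext.
      def countWithCallbacks(sc: SparkContext) {
        val sem = new Semaphore(0)
        val f = sc.parallelize(1 to 10, 2).countAsync()
        f.onComplete {
          case scala.util.Success(count) => sem.release()  // success path
          case scala.util.Failure(e)     => sem.release()  // failure path; e carries the cause
        }
        assert(f.get() == 10)  // blocking variant of the same result
        sem.acquire()          // don't return until the callback has fired
      }
    }
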
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 31f97fc139..57d3382ed0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
}
visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection.
}
test("join") {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8fd1115207..d8dcd6d14c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
+ test("repartitioned RDDs") {
+ val data = sc.parallelize(1 to 1000, 10)
+
+ // Coalesce partitions
+ val repartitioned1 = data.repartition(2)
+ assert(repartitioned1.partitions.size == 2)
+ val partitions1 = repartitioned1.glom().collect()
+ assert(partitions1(0).length > 0)
+ assert(partitions1(1).length > 0)
+ assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions
+ val repartitioned2 = data.repartition(20)
+ assert(repartitioned2.partitions.size == 20)
+ val partitions2 = repartitioned2.glom().collect()
+ assert(partitions2(0).length > 0)
+ assert(partitions2(19).length > 0)
+ assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+ }
+
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)
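
A condensed sketch of the repartition behaviour covered by the Java and Scala tests above, assuming a live SparkContext named sc:

    import org.apache.spark.SparkContext

    object RepartitionSketch {
      // Sketch only: `sc` is assumed to be a live SparkContext.
      def demo(sc: SparkContext) {
        val data = sc.parallelize(1 to 1000, 10)

        val grown = data.repartition(20)     // split into more partitions
        assert(grown.partitions.size == 20)
        assert(grown.collect().toSet == (1 to 1000).toSet)  // contents are preserved

        val shrunk = data.repartition(2)     // coalesce (with shuffle) into fewer partitions
        assert(shrunk.partitions.size == 2)
        assert(shrunk.glom().collect().forall(_.nonEmpty))  // every partition gets some data
      }
    }
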
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2f933246b0..00f2fdd657 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,16 +23,15 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
-import org.apache.spark.rdd.RDD
+import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason}
-import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
@@ -60,11 +59,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
- override def setListener(listener: TaskSchedulerListener) = {}
+ override def cancelTasks(stageId: Int) {}
+ override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
}
- var mapOutputTracker: MapOutputTracker = null
+ var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@@ -75,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
// stub out BlockManagerMaster.getLocations to use our cacheLocations
val blockManagerMaster = new BlockManagerMaster(null) {
- override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
- blockIds.map { name =>
- val pieces = name.split("_")
- if (pieces(0) == "rdd") {
- val key = pieces(1).toInt -> pieces(2).toInt
- cacheLocations.getOrElse(key, Seq())
- } else {
- Seq()
- }
+ override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map {
+ _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
+ getOrElse(Seq())
}.toSeq
}
override def removeExecutor(execId: String) {
@@ -104,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTracker()
+ mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
@@ -186,7 +181,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
allowLocal: Boolean = false,
listener: JobListener = listener) {
- runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener))
}
/** Sends TaskSetFailed to the scheduler. */
@@ -220,7 +216,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def getPreferredLocations(split: Partition) = Nil
override def toString = "DAGSchedulerSuite Local RDD"
}
- runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
}
@@ -247,7 +244,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
- assert(failure.getMessage === "Job failed: some failure")
+ assert(failure.getMessage === "Job aborted: some failure")
}
test("run trivial shuffle") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index cece60dda7..984881861c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -58,11 +58,14 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
- val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
- val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
-
- joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
- joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ val shuffleMapStage =
+ new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+ val rootStage =
+ new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
+ val rootStageInfo = new StageInfo(rootStage)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
@@ -88,8 +91,10 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+
+ val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
- joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 0d8742cb81..2e41438a52 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,16 +17,62 @@
package org.apache.spark.scheduler
-import org.scalatest.FunSuite
-import org.apache.spark.{SparkContext, LocalSparkContext}
-import scala.collection.mutable
+import scala.collection.mutable.{Buffer, HashSet}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+ with BeforeAndAfterAll {
+ /** Length of time to wait while draining listener events. */
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
+ }
+
+ test("basic creation of StageInfo") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ rdd2.setName("Target RDD")
+ rdd2.count
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.rddName should be {"Target RDD"}
+ first.numTasks should be {4}
+ first.numPartitions should be {4}
+ first.submissionTime should be ('defined)
+ first.completionTime should be ('defined)
+ first.taskInfos.length should be {4}
+ }
+
+ test("StageInfo with fewer tasks than partitions") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.numTasks should be {2}
+ first.numPartitions should be {4}
+ }
test("local metrics") {
- sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -37,9 +83,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
i
}
- val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
- d.count
- val WAIT_TIMEOUT_MILLIS = 10000
+ val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)}
+ d.count()
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
@@ -50,7 +95,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
d4.setName("A Cogroup")
- d4.collectAsMap
+ d4.collectAsMap()
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (4)
@@ -64,7 +109,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@@ -72,11 +117,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
- if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@@ -89,20 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
- def checkNonZeroAvg(m: Traversable[Long], msg: String) {
- assert(m.sum / m.size.toDouble > 0.0, msg)
+ test("onTaskGettingResult() called when result fetched remotely") {
+    // Need to use local cluster mode here, because results are never returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ System.setProperty("spark.akka.frameSize", "1")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
+ assert(listener.endedTasks.contains(TASK_INDEX))
}
- def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
- val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
- !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ test("onTaskGettingResult() not called when result sent directly") {
+    // Need to use local cluster mode here, because results are never returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+    // Make a task whose result is smaller than the akka frame size
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+    assert(listener.startedGettingResultTasks.isEmpty)
+ assert(listener.endedTasks.contains(TASK_INDEX))
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
}
class SaveStageInfo extends SparkListener {
- val stageInfos = mutable.Buffer[StageInfo]()
+ val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
- stageInfos += stage.stageInfo
+ stageInfos += stage.stage
}
}
+ class SaveTaskEvents extends SparkListener {
+ val startedTasks = new HashSet[Int]()
+ val startedGettingResultTasks = new HashSet[Int]()
+ val endedTasks = new HashSet[Int]()
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ startedTasks += taskStart.taskInfo.index
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ endedTasks += taskEnd.taskInfo.index
+ }
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+ startedGettingResultTasks += taskGettingResult.taskInfo.index
+ }
+ }
}
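
A minimal sketch of the listener pattern the suite above uses to observe task lifecycle events, assuming a live SparkContext named sc; atomic counters stand in for the suite's HashSets:

    import java.util.concurrent.atomic.AtomicInteger

    import org.apache.spark.SparkContext
    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}

    object TaskCountingListenerSketch {
      // Sketch only: `sc` is assumed to be a live SparkContext.
      def demo(sc: SparkContext) {
        val started = new AtomicInteger(0)
        val ended = new AtomicInteger(0)
        sc.addSparkListener(new SparkListener {
          override def onTaskStart(taskStart: SparkListenerTaskStart) { started.incrementAndGet() }
          override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { ended.incrementAndGet() }
        })
        sc.parallelize(1 to 100, 4).count()
        // Listener events are posted asynchronously; the suite above waits on
        // sc.dagScheduler.listenerBus.waitUntilEmpty(...) before asserting on them.
      }
    }
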
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index 80d0c5a5e9..b97f2b19b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}
+class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ taskScheduler.startedTasks += taskInfo.index
+ }
+
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: mutable.Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ taskScheduler.endedTasks(taskInfo.index) = reason
+ }
+
+ override def executorGained(execId: String, host: String) {}
+
+ override def executorLost(execId: String) {}
+
+ override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ taskScheduler.taskSetsFailed += taskSet.id
+ }
+}
+
/**
* A mock ClusterScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
@@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
val executors = new mutable.HashMap[String, String] ++ liveExecutors
- listener = new TaskSchedulerListener {
- def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- startedTasks += taskInfo.index
- }
-
- def taskEnded(
- task: Task[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: mutable.Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics)
- {
- endedTasks(taskInfo.index) = reason
- }
-
- def executorGained(execId: String, host: String) {}
-
- def executorLost(execId: String) {}
-
- def taskSetFailed(taskSet: TaskSet, reason: String) {
- taskSetsFailed += taskSet.id
- }
- }
+ dagScheduler = new FakeDAGScheduler(this)
def removeExecutor(execId: String): Unit = executors -= execId
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
index 2f12aaed18..0f01515179 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
@@ -17,10 +17,11 @@
package org.apache.spark.scheduler.cluster
+import org.apache.spark.TaskContext
import org.apache.spark.scheduler.{TaskLocation, Task}
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
- override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+ override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
index a00198db8c..27c2d53361 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.storage.TaskResultBlockId
/**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@@ -85,7 +86,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
- val RESULT_BLOCK_ID = "taskresult_0"
+ val RESULT_BLOCK_ID = TaskResultBlockId(0)
assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
"Expect result to be removed from the block manager.")
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
index af76c843e8..1e676c1719 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.scheduler.local
-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch
-import java.util.Properties
+
+import scala.collection.mutable.HashMap
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+
class Lock() {
var finished = false
@@ -63,7 +61,12 @@ object TaskThreadInfo {
* 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
* thus it will be scheduled later when cluster has free cpu cores.
*/
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.scheduler.mode")
+ }
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
@@ -148,12 +151,13 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
}
test("Local fair scheduler end-to-end test") {
- sc = new SparkContext("local[8]", "LocalSchedulerSuite")
- val sem = new Semaphore(0)
System.setProperty("spark.scheduler.mode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+
createThread(10,"1",sc,sem)
TaskThreadInfo.threadToStarted(10).await()
createThread(20,"2",sc,sem)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
new file mode 100644
index 0000000000..cb76275e39
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+
+class BlockIdSuite extends FunSuite {
+ def assertSame(id1: BlockId, id2: BlockId) {
+ assert(id1.name === id2.name)
+ assert(id1.hashCode === id2.hashCode)
+ assert(id1 === id2)
+ }
+
+ def assertDifferent(id1: BlockId, id2: BlockId) {
+ assert(id1.name != id2.name)
+ assert(id1.hashCode != id2.hashCode)
+ assert(id1 != id2)
+ }
+
+ test("test-bad-deserialization") {
+ try {
+ // Try to deserialize an invalid block id.
+ BlockId("myblock")
+ fail()
+ } catch {
+ case e: IllegalStateException => // OK
+ case _ => fail()
+ }
+ }
+
+ test("rdd") {
+ val id = RDDBlockId(1, 2)
+ assertSame(id, RDDBlockId(1, 2))
+ assertDifferent(id, RDDBlockId(1, 1))
+ assert(id.name === "rdd_1_2")
+ assert(id.asRDDId.get.rddId === 1)
+ assert(id.asRDDId.get.splitIndex === 2)
+ assert(id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("shuffle") {
+ val id = ShuffleBlockId(1, 2, 3)
+ assertSame(id, ShuffleBlockId(1, 2, 3))
+ assertDifferent(id, ShuffleBlockId(3, 2, 3))
+ assert(id.name === "shuffle_1_2_3")
+ assert(id.asRDDId === None)
+ assert(id.shuffleId === 1)
+ assert(id.mapId === 2)
+ assert(id.reduceId === 3)
+ assert(id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("broadcast") {
+ val id = BroadcastBlockId(42)
+ assertSame(id, BroadcastBlockId(42))
+ assertDifferent(id, BroadcastBlockId(123))
+ assert(id.name === "broadcast_42")
+ assert(id.asRDDId === None)
+ assert(id.broadcastId === 42)
+ assert(id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("taskresult") {
+ val id = TaskResultBlockId(60)
+ assertSame(id, TaskResultBlockId(60))
+ assertDifferent(id, TaskResultBlockId(61))
+ assert(id.name === "taskresult_60")
+ assert(id.asRDDId === None)
+ assert(id.taskId === 60)
+ assert(!id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("stream") {
+ val id = StreamBlockId(1, 100)
+ assertSame(id, StreamBlockId(1, 100))
+ assertDifferent(id, StreamBlockId(2, 101))
+ assert(id.name === "input-1-100")
+ assert(id.asRDDId === None)
+ assert(id.streamId === 1)
+ assert(id.uniqueId === 100)
+ assert(!id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("test") {
+ val id = TestBlockId("abc")
+ assertSame(id, TestBlockId("abc"))
+ assertDifferent(id, TestBlockId("ab"))
+ assert(id.name === "test_abc")
+ assert(id.asRDDId === None)
+ assert(id.id === "abc")
+ assert(!id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 038a9acb85..484a654108 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null
var store2: BlockManager = null
@@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+ def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
+
before {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
this.actorSystem = actorSystem
@@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
// Putting a1, a2 and a3 in memory.
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("nonrddblock") should not be (None)
master.getLocations("nonrddblock") should have size (1)
}
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = true)
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
test("reregistration on heart beat") {
@@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
- store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY)
// Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
- assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store")
// Check that rdd_0_3 doesn't replace them even after further accesses
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
}
test("in-memory LRU for partitions of multiple RDDs") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
- store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// At this point rdd_1_1 should've replaced rdd_0_1
- assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
// Do a get() on rdd_0_2 so that it is the most recently used item
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
// Put in more partitions from RDD 0; they should replace rdd_1_1
- store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped
// when we try to add rdd_0_4.
- assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
- assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store")
+ assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
}
test("on-disk storage") {
@@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
try {
System.setProperty("spark.shuffle.compress", "true")
store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
+ "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
+ "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
+ "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
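
The hunks above swap the raw string block names for typed block ids. A minimal sketch of the naming convention the new assertions rely on, assuming the `rdd(_, _)` helper wraps `RDDBlockId` and that each id's `name` field reproduces the old string key (constructor arguments shown are illustrative):

    import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId, BroadcastBlockId}

    // Typed ids render the same keys the old string-based tests used.
    val rddBlock = RDDBlockId(0, 1)                 // was "rdd_0_1"
    val shuffleBlock = ShuffleBlockId(0, 0, 0)      // was "shuffle_0_0_0"
    val broadcastBlock = BroadcastBlockId(0)        // was "broadcast_0"

    assert(rddBlock.name == "rdd_0_1")
    assert(shuffleBlock.name == "shuffle_0_0_0")
    assert(broadcastBlock.name == "broadcast_0")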
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000..0b9056344c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+
+ val rootDir0 = Files.createTempDir()
+ rootDir0.deleteOnExit()
+ val rootDir1 = Files.createTempDir()
+ rootDir1.deleteOnExit()
+ val rootDirs = rootDir0.getName + "," + rootDir1.getName
+ println("Created root dirs: " + rootDirs)
+
+ val shuffleBlockManager = new ShuffleBlockManager(null) {
+ var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+ override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+ }
+
+ var diskBlockManager: DiskBlockManager = _
+
+ override def beforeEach() {
+ diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+ shuffleBlockManager.idToSegmentMap.clear()
+ }
+
+ test("basic block creation") {
+ val blockId = new TestBlockId("test")
+ assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 10)
+ assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+ newFile.delete()
+ }
+
+ test("block appending") {
+ val blockId = new TestBlockId("test")
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 15)
+ assertSegmentEquals(blockId, blockId.name, 0, 15)
+ val newFile2 = diskBlockManager.getFile(blockId)
+ assert(newFile === newFile2)
+ writeToFile(newFile2, 12)
+ assertSegmentEquals(blockId, blockId.name, 0, 27)
+ newFile.delete()
+ }
+
+ test("block remapping") {
+ val filename = "test"
+ val blockId0 = new ShuffleBlockId(1, 2, 3)
+ val newFile = diskBlockManager.getFile(filename)
+ writeToFile(newFile, 15)
+ shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+ assertSegmentEquals(blockId0, filename, 0, 15)
+
+ val blockId1 = new ShuffleBlockId(1, 2, 4)
+ val newFile2 = diskBlockManager.getFile(filename)
+ writeToFile(newFile2, 12)
+ shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+ assertSegmentEquals(blockId1, filename, 15, 12)
+
+ assert(newFile === newFile2)
+ newFile.delete()
+ }
+
+ def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+ val segment = diskBlockManager.getBlockLocation(blockId)
+ assert(segment.file.getName === filename)
+ assert(segment.offset === offset)
+ assert(segment.length === length)
+ }
+
+ def writeToFile(file: File, numBytes: Int) {
+ val writer = new FileWriter(file, true)
+ for (i <- 0 until numBytes) writer.write(i)
+ writer.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
new file mode 100644
index 0000000000..7177919a58
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable.HashSet
+
+import org.scalatest.FunSuite
+
+class AppendOnlyMapSuite extends FunSuite {
+ test("initialization") {
+ val goodMap1 = new AppendOnlyMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new AppendOnlyMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new AppendOnlyMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new AppendOnlyMap[Int, Int](0)
+ }
+ }
+
+ test("object keys and values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map("" + i) === "" + i)
+ }
+ assert(map("0") === null)
+ assert(map("101") === null)
+ assert(map(null) === null)
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet)
+ }
+
+ test("primitive keys and values") {
+ val map = new AppendOnlyMap[Int, Int]()
+ for (i <- 1 to 100) {
+ map(i) = i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i) === i)
+ }
+ assert(map(0) === null)
+ assert(map(101) === null)
+ val set = new HashSet[(Int, Int)]
+ for ((k, v) <- map) { // Test the foreach method
+ set += ((k, v))
+ }
+ assert(set === (1 to 100).map(x => (x, x)).toSet)
+ }
+
+ test("null keys") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new AppendOnlyMap[String, String]()
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "" + i)
+ oldValue + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue("" + i, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ i + "!"
+ })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === false)
+ "null!"
+ })
+ assert(map.size === 401)
+ map.changeValue(null, (hadValue, oldValue) => {
+ assert(hadValue === true)
+ assert(oldValue === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new AppendOnlyMap[String, String](1)
+ for (i <- 1 to 100) {
+ map("" + i) = "" + i
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map("" + i) === "" + i)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
new file mode 100644
index 0000000000..0f1ab3d20e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class BitSetSuite extends FunSuite {
+
+ test("basic set and get") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+
+ for (i <- 0 until 100) {
+ assert(!bitset.get(i))
+ }
+
+ setBits.foreach(i => bitset.set(i))
+
+ for (i <- 0 until 100) {
+ if (setBits.contains(i)) {
+ assert(bitset.get(i))
+ } else {
+ assert(!bitset.get(i))
+ }
+ }
+ assert(bitset.cardinality() === setBits.size)
+ }
+
+ test("100% full bit set") {
+ val bitset = new BitSet(10000)
+ for (i <- 0 until 10000) {
+ assert(!bitset.get(i))
+ bitset.set(i)
+ }
+ for (i <- 0 until 10000) {
+ assert(bitset.get(i))
+ }
+ assert(bitset.cardinality() === 10000)
+ }
+
+ test("nextSetBit") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+ setBits.foreach(i => bitset.set(i))
+
+ assert(bitset.nextSetBit(0) === 0)
+ assert(bitset.nextSetBit(1) === 1)
+ assert(bitset.nextSetBit(2) === 9)
+ assert(bitset.nextSetBit(9) === 9)
+ assert(bitset.nextSetBit(10) === 10)
+ assert(bitset.nextSetBit(11) === 90)
+ assert(bitset.nextSetBit(80) === 90)
+ assert(bitset.nextSetBit(91) === 96)
+ assert(bitset.nextSetBit(96) === 96)
+ assert(bitset.nextSetBit(97) === -1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
new file mode 100644
index 0000000000..ca3f684668
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -0,0 +1,148 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+
+class OpenHashMapSuite extends FunSuite {
+
+ test("initialization") {
+ val goodMap1 = new OpenHashMap[String, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new OpenHashMap[String, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new OpenHashMap[String, String](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, String](0)
+ }
+ }
+
+ test("primitive value") {
+ val map = new OpenHashMap[String, Int]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i
+ assert(map(i.toString) === i)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === 0)
+
+ map(null) = -1
+ assert(map.size === 1001)
+ assert(map(null) === -1)
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1)
+ assert(set === expected.toSet)
+ }
+
+ test("non-primitive value") {
+ val map = new OpenHashMap[String, String]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i.toString
+ assert(map(i.toString) === i.toString)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === null)
+
+ map(null) = "-1"
+ assert(map.size === 1001)
+ assert(map(null) === "-1")
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i.toString)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1")
+ assert(set === expected.toSet)
+ }
+
+ test("null keys") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toString, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, { "null!" }, v => { assert(false); v })
+ assert(map.size === 401)
+ map.changeValue(null, { assert(false); "" }, v => {
+ assert(v === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new OpenHashMap[String, String](1)
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toString) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
new file mode 100644
index 0000000000..4e11e8a628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -0,0 +1,145 @@
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class OpenHashSetSuite extends FunSuite {
+
+ test("primitive int") {
+ val set = new OpenHashSet[Int]
+ assert(set.size === 0)
+ assert(!set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(10)
+ assert(set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 2)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(999)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+ }
+
+ test("primitive long") {
+ val set = new OpenHashSet[Long]
+ assert(set.size === 0)
+ assert(!set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(10L)
+ assert(set.size === 1)
+ assert(set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 2)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(999L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+ }
+
+ test("non-primitive") {
+ val set = new OpenHashSet[String]
+ assert(set.size === 0)
+ assert(!set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(10.toString)
+ assert(set.size === 1)
+ assert(set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 2)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(999.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+ }
+
+ test("non-primitive set growth") {
+ val set = new OpenHashSet[String]
+ for (i <- 1 to 1000) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+
+ test("primitive set growth") {
+ val set = new OpenHashSet[Long]
+ for (i <- 1 to 1000) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
new file mode 100644
index 0000000000..dfd6aed2c4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
@@ -0,0 +1,90 @@
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+
+class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+
+ test("initialization") {
+ val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](0)
+ }
+ }
+
+ test("basic operations") {
+ val longBase = 1000000L
+ val map = new PrimitiveKeyOpenHashMap[Long, Int]
+
+ for (i <- 1 to 1000) {
+ map(i + longBase) = i
+ assert(map(i + longBase) === i)
+ }
+
+ assert(map.size === 1000)
+
+ for (i <- 1 to 1000) {
+ assert(map(i + longBase) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(Long, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet)
+ }
+
+ test("null values") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = null
+ }
+ assert(map.size === 100)
+ assert(map(1.toLong) === null)
+ }
+
+ test("changeValue") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toLong, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String](1)
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toLong) === i.toString)
+ }
+ }
+}
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000000..bf59e77d11
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,5 @@
+Spark docker files
+===========
+
+Drawn from Matt Massie's docker files (https://github.com/massie/dockerfiles),
+as well as some updates from Andre Schumacher (https://github.com/AndreSchumacher/docker). \ No newline at end of file
diff --git a/docker/build b/docker/build
new file mode 100755
index 0000000000..253a2fc8dd
--- /dev/null
+++ b/docker/build
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+docker images > /dev/null || { echo Please install docker in non-sudo mode. ; exit; }
+
+./spark-test/build \ No newline at end of file
diff --git a/docker/spark-test/README.md b/docker/spark-test/README.md
new file mode 100644
index 0000000000..ec0baf6e6d
--- /dev/null
+++ b/docker/spark-test/README.md
@@ -0,0 +1,11 @@
+Spark Docker files usable for testing and development purposes.
+
+These images are intended to be run like so:
+
+ docker run -v $SPARK_HOME:/opt/spark spark-test-master
+ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://<master_ip>:7077
+
+Using this configuration, the containers will have their Spark directories
+mounted to your actual `SPARK_HOME`, allowing you to modify and recompile
+your Spark source and have them immediately usable in the docker images
+(without rebuilding them).
diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile
new file mode 100644
index 0000000000..60962776dd
--- /dev/null
+++ b/docker/spark-test/base/Dockerfile
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM ubuntu:precise
+
+RUN echo "deb http://archive.ubuntu.com/ubuntu precise main universe" > /etc/apt/sources.list
+
+# Upgrade package index
+RUN apt-get update
+
+# install a few other useful packages plus Open Jdk 7
+RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server
+
+ENV SCALA_VERSION 2.9.3
+ENV SPARK_VERSION 0.8.1
+ENV CDH_VERSION cdh4
+ENV SCALA_HOME /opt/scala-$SCALA_VERSION
+ENV SPARK_HOME /opt/spark
+ENV PATH $SPARK_HOME:$SCALA_HOME/bin:$PATH
+
+# Install Scala
+ADD http://www.scala-lang.org/files/archive/scala-$SCALA_VERSION.tgz /
+RUN (cd / && gunzip < scala-$SCALA_VERSION.tgz)|(cd /opt && tar -xvf -)
+RUN rm /scala-$SCALA_VERSION.tgz
diff --git a/docker/spark-test/build b/docker/spark-test/build
new file mode 100755
index 0000000000..6f9e197433
--- /dev/null
+++ b/docker/spark-test/build
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+docker build -t spark-test-base spark-test/base/
+docker build -t spark-test-master spark-test/master/
+docker build -t spark-test-worker spark-test/worker/
diff --git a/docker/spark-test/master/Dockerfile b/docker/spark-test/master/Dockerfile
new file mode 100644
index 0000000000..f729534ab6
--- /dev/null
+++ b/docker/spark-test/master/Dockerfile
@@ -0,0 +1,21 @@
+# Spark Master
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM spark-test-base
+ADD default_cmd /root/
+CMD ["/root/default_cmd"]
diff --git a/docker/spark-test/master/default_cmd b/docker/spark-test/master/default_cmd
new file mode 100755
index 0000000000..a5b1303c2e
--- /dev/null
+++ b/docker/spark-test/master/default_cmd
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
+echo "CONTAINER_IP=$IP"
+/opt/spark/spark-class org.apache.spark.deploy.master.Master -i $IP
diff --git a/docker/spark-test/worker/Dockerfile b/docker/spark-test/worker/Dockerfile
new file mode 100644
index 0000000000..890febe7b6
--- /dev/null
+++ b/docker/spark-test/worker/Dockerfile
@@ -0,0 +1,22 @@
+# Spark Worker
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FROM spark-test-base
+ENV SPARK_WORKER_PORT 8888
+ADD default_cmd /root/
+ENTRYPOINT ["/root/default_cmd"]
diff --git a/docker/spark-test/worker/default_cmd b/docker/spark-test/worker/default_cmd
new file mode 100755
index 0000000000..ab6336f70c
--- /dev/null
+++ b/docker/spark-test/worker/default_cmd
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+IP=$(ip -o -4 addr list eth0 | perl -n -e 'if (m{inet\s([\d\.]+)\/\d+\s}xms) { print $1 }')
+echo "CONTAINER_IP=$IP"
+/opt/spark/spark-class org.apache.spark.deploy.worker.Worker $1
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index f679cad713..5927f736f3 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -13,7 +13,7 @@ object in your main program (called the _driver program_).
Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_
(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across
applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are
-worker processes that run computations and store data for your application.
+worker processes that run computations and store data for your application.
Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to
the executors. Finally, SparkContext sends *tasks* for the executors to run.
@@ -57,6 +57,18 @@ which takes a list of JAR files (Java/Scala) or .egg and .zip libraries (Python)
worker nodes. You can also dynamically add new files to be sent to executors with `SparkContext.addJar`
and `addFile`.
+## URIs for addJar / addFile
+
+- **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor
+ pulls the file from the driver HTTP server
+- **hdfs:**, **http:**, **https:**, **ftp:** - these pull down files and JARs from the URI as expected
+- **local:** - a URI starting with local:/ is expected to exist as a local file on each worker node. This
+ means that no network IO will be incurred, and works well for large files/JARs that are pushed to each worker,
+ or shared via NFS, GlusterFS, etc.
+
+Note that JARs and files are copied to the working directory for each SparkContext on the executor nodes.
+Over time this can use up a significant amount of space and will need to be cleaned up.
+
# Monitoring
Each driver program has a web UI, typically on port 4040, that displays information about running
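
For the `addJar`/`addFile` URI schemes documented above, a hedged Scala sketch (the master URL, paths, and file names are placeholders, not part of this patch):

    import org.apache.spark.SparkContext

    val sc = new SparkContext("spark://master:7077", "UriDemo")

    sc.addJar("file:///opt/jars/deps.jar")             // served from the driver's HTTP file server
    sc.addJar("hdfs://namenode:9000/jars/deps.jar")    // pulled from HDFS by each executor
    sc.addJar("local:/opt/jars/deps.jar")              // must already exist on every worker node
    sc.addFile("hdfs://namenode:9000/data/lookup.txt")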
diff --git a/docs/configuration.md b/docs/configuration.md
index 7940d41a27..97183bafdb 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -149,7 +149,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.io.compression.codec</td>
<td>org.apache.spark.io.<br />LZFCompressionCodec</td>
<td>
- The compression codec class to use for various compressions. By default, Spark provides two
+ The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, Spark provides two
codecs: <code>org.apache.spark.io.LZFCompressionCodec</code> and <code>org.apache.spark.io.SnappyCompressionCodec</code>.
</td>
</tr>
@@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful
Should be greater than or equal to 1. Number of allowed retries = this value - 1.
</td>
</tr>
+<tr>
+ <td>spark.broadcast.blockSize</td>
+ <td>4096</td>
+ <td>
+ Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>.
+ Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit.
+ </td>
+</tr>
</table>
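
These entries, like the rest of the table, are read from JVM system properties that must be set before the SparkContext is created; a brief sketch with illustrative values:

    import org.apache.spark.SparkContext

    // Set before constructing the SparkContext, otherwise the defaults apply.
    System.setProperty("spark.io.compression.codec", "org.apache.spark.io.SnappyCompressionCodec")
    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
    System.setProperty("spark.broadcast.blockSize", "8192")   // kilobytes per broadcast piece

    val sc = new SparkContext("local[2]", "ConfigDemo")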
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 1e5575d657..156a727026 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -98,7 +98,7 @@ permissions on your private key file, you can run `launch` with the
`bin/hadoop` script in that directory. Note that the data in this
HDFS goes away when you stop and restart a machine.
- There is also a *persistent HDFS* instance in
- `/root/presistent-hdfs` that will keep data across cluster restarts.
+ `/root/persistent-hdfs` that will keep data across cluster restarts.
Typically each node has relatively little space of persistent data
(about 3 GB), but you can use the `--ebs-vol-size` option to
`spark-ec2` to attach a persistent EBS volume to each node for
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 6c2336ad0c..55e39b1de1 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -131,6 +131,17 @@ sc = SparkContext("local", "App Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg
Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
+You can set [system properties](configuration.html#system-properties)
+using `SparkContext.setSystemProperty()` class method *before*
+instantiating SparkContext. For example, to set the amount of memory
+per executor process:
+
+{% highlight python %}
+from pyspark import SparkContext
+SparkContext.setSystemProperty('spark.executor.memory', '2g')
+sc = SparkContext("local", "App Name")
+{% endhighlight %}
+
# API Docs
[API documentation](api/pyspark/index.html) for PySpark is available as Epydoc.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 30128ec45d..2898af0bed 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -34,6 +34,8 @@ Environment variables:
System Properties:
* 'spark.yarn.applicationMaster.waitTries', property to set the number of times the ApplicationMaster waits for the spark master and then also the number of tries it waits for the Spark Context to be initialized. Default is 10.
+* 'spark.yarn.submit.file.replication', the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives.
+* 'spark.yarn.preserve.staging.files', set to true to preserve the staged files (spark jar, app jar, distributed cache files) at the end of the job rather than delete them.
# Launching Spark on YARN
@@ -51,7 +53,10 @@ The command to launch the YARN Client is as follows:
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
--name <application_name> \
- --queue <queue_name>
+ --queue <queue_name> \
+ --addJars <any_local_files_used_in_SparkContext.addJar> \
+ --files <files_for_distributed_cache> \
+ --archives <archives_for_distributed_cache>
For example:
@@ -84,3 +89,5 @@ The above starts a YARN Client programs which periodically polls the Application
- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
- We do not request container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
+- The --files and --archives options support specifying file names with the # syntax, similar to Hadoop. For example, you can specify --files localtest.txt#appSees.txt; this uploads the file you have locally named localtest.txt into HDFS, but it will be linked to by the name appSees.txt, and your application should use the name appSees.txt to reference it when running on YARN.
+- The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 03647a2ad2..94e8563a8b 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -142,7 +142,7 @@ All transformations in Spark are <i>lazy</i>, in that they do not compute their
By default, each transformed RDD is recomputed each time you run an action on it. However, you may also *persist* an RDD in memory using the `persist` (or `cache`) method, in which case Spark will keep the elements around on the cluster for much faster access the next time you query it. There is also support for persisting datasets on disk, or replicated across the cluster. The next section in this document describes these options.
-The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.RDD) for details):
+The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD) for details):
### Transformations
@@ -211,7 +211,7 @@ The following tables list the transformations and actions currently supported (s
</tr>
</table>
-A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
+A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
### Actions
@@ -259,7 +259,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in
</tr>
</table>
-A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
+A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
## RDD Persistence
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 81cdbefd0c..17066ef0dd 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -3,6 +3,9 @@ layout: global
title: Spark Standalone Mode
---
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
In addition to running on the Mesos or YARN cluster managers, Spark also provides a simple standalone deploy mode. You can launch a standalone cluster either manually, by starting a master and workers by hand, or use our provided [launch scripts](#cluster-launch-scripts). It is also possible to run these daemons on a single machine for testing.
# Installing Spark Standalone to a Cluster
@@ -169,3 +172,75 @@ In addition, detailed log output for each job is also written to the work direct
You can run Spark alongside your existing Hadoop cluster by just launching it as a separate service on the same machines. To access Hadoop data from Spark, just use a hdfs:// URL (typically `hdfs://<namenode>:9000/path`, but you can find the right URL on your Hadoop Namenode's web UI). Alternatively, you can set up a separate cluster for Spark, and still have it access HDFS over the network; this will be slower than disk-local access, but may not be a concern if you are still running in the same local area network (e.g. you place a few Spark machines on each rack that you have Hadoop on).
+
+# High Availability
+
+By default, standalone scheduling clusters are resilient to Worker failures (insofar as Spark itself is resilient to losing work by moving it to other workers). However, the scheduler uses a Master to make scheduling decisions, and this (by default) creates a single point of failure: if the Master crashes, no new applications can be created. In order to circumvent this, we have two high availability schemes, detailed below.
+
+## Standby Masters with ZooKeeper
+
+**Overview**
+
+Utilizing ZooKeeper to provide leader election and some state storage, you can launch multiple Masters in your cluster connected to the same ZooKeeper instance. One will be elected "leader" and the others will remain in standby mode. If the current leader dies, another Master will be elected, recover the old Master's state, and then resume scheduling. The entire recovery process (from the time the first leader goes down) should take between 1 and 2 minutes. Note that this delay only affects scheduling _new_ applications -- applications that were already running during Master failover are unaffected.
+
+Learn more about getting started with ZooKeeper [here](http://zookeeper.apache.org/doc/trunk/zookeeperStarted.html).
+
+**Configuration**
+
+In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration:
+
+<table class="table">
+ <tr><th style="width:21%">System property</th><th>Meaning</th></tr>
+ <tr>
+ <td><code>spark.deploy.recoveryMode</code></td>
+ <td>Set to ZOOKEEPER to enable standby Master recovery mode (default: NONE).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.zookeeper.url</code></td>
+ <td>The ZooKeeper cluster url (e.g., 192.168.1.100:2181,192.168.1.101:2181).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.zookeeper.dir</code></td>
+ <td>The directory in ZooKeeper to store recovery state (default: /spark).</td>
+ </tr>
+</table>
+
+Possible gotcha: If you have multiple Masters in your cluster but fail to correctly configure the Masters to use ZooKeeper, the Masters will fail to discover each other and think they're all leaders. This will not lead to a healthy cluster state (as all Masters will schedule independently).
+
+**Details**
+
+After you have a ZooKeeper cluster set up, enabling high availability is straightforward. Simply start multiple Master processes on different nodes with the same ZooKeeper configuration (ZooKeeper URL and directory). Masters can be added and removed at any time.
+
+In order to schedule new applications or add Workers to the cluster, they need to know the IP address of the current leader. This can be accomplished by simply passing in a list of Masters where you used to pass in a single one. For example, you might start your SparkContext pointing to ``spark://host1:port1,host2:port2``. This would cause your SparkContext to try registering with both Masters -- if ``host1`` goes down, this configuration would still be correct as we'd find the new leader, ``host2``.
+
+There's an important distinction to be made between "registering with a Master" and normal operation. When starting up, an application or Worker needs to be able to find and register with the current lead Master. Once it successfully registers, though, it is "in the system" (i.e., stored in ZooKeeper). If failover occurs, the new leader will contact all previously registered applications and Workers to inform them of the change in leadership, so they need not even have known of the existence of the new Master at startup.
+
+Due to this property, new Masters can be created at any time, and the only thing you need to worry about is that _new_ applications and Workers can find it to register with in case it becomes the leader. Once registered, you're taken care of.
+
+## Single-Node Recovery with Local File System
+
+**Overview**
+
+ZooKeeper is the best way to go for production-level high availability, but if you just want to be able to restart the Master if it goes down, FILESYSTEM mode can take care of it. When applications and Workers register, they have enough state written to the provided directory so that they can be recovered upon a restart of the Master process.
+
+**Configuration**
+
+In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env using this configuration:
+
+<table class="table">
+ <tr><th style="width:21%">System property</th><th>Meaning</th></tr>
+ <tr>
+ <td><code>spark.deploy.recoveryMode</code></td>
+ <td>Set to FILESYSTEM to enable single-node recovery mode (default: NONE).</td>
+ </tr>
+ <tr>
+ <td><code>spark.deploy.recoveryDirectory</code></td>
+ <td>The directory in which Spark will store recovery state, accessible from the Master's perspective.</td>
+ </tr>
+</table>
+
+**Details**
+
+* This solution can be used in tandem with a process monitor/manager like [monit](http://mmonit.com/monit/), or just to enable manual recovery via restart.
+* While filesystem recovery seems straightforwardly better than not doing any recovery at all, this mode may be suboptimal for certain development or experimental purposes. In particular, killing a master via stop-master.sh does not clean up its recovery state, so whenever you start a new Master, it will enter recovery mode. This could increase the startup time by up to 1 minute if it needs to wait for all previously-registered Workers/clients to time out.
+* While it's not officially supported, you could mount an NFS directory as the recovery directory. If the original Master node dies completely, you could then start a Master on a different node, which would correctly recover all previously registered Workers/applications (equivalent to ZooKeeper recovery). Future applications will have to be able to find the new Master, however, in order to register.
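
A minimal sketch of the multi-Master URL described in the ZooKeeper section above (host names and ports are placeholders):

    import org.apache.spark.SparkContext

    // The context registers with whichever Master is currently the leader;
    // if host1 goes down, registration falls back to host2.
    val sc = new SparkContext("spark://host1:7077,host2:7077", "HaDemo")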
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index c7df172024..851e30fe76 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -73,6 +73,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
</tr>
<tr>
+ <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+ <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
+<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
</tr>
@@ -122,12 +126,12 @@ Spark Streaming features windowed computations, which allow you to apply transfo
<table class="table">
<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr>
<tr>
- <td> <b>window</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> <b>window</b>(<i>windowDuration</i>, <i>slideDuration</i>) </td>
<td> Return a new DStream which is computed based on windowed batches of the source DStream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval.
</td>
</tr>
<tr>
- <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td>
+ <td> <b>countByWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>) </td>
<td> Return a sliding count of elements in the stream. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
@@ -161,7 +165,6 @@ Spark Streaming features windowed computations, which allow you to apply transfo
<i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>.
</td>
</tr>
-
</table>
A complete list of DStream operations is available in the API documentation of [DStream](api/streaming/index.html#org.apache.spark.streaming.DStream) and [PairDStreamFunctions](api/streaming/index.html#org.apache.spark.streaming.PairDStreamFunctions).
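
A short Scala sketch of the `repartition` and windowed operations listed above (the socket source, durations, and partition count are illustrative only):

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    val ssc = new StreamingContext("local[2]", "WindowDemo", Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)

    // Window width and slide interval must both be multiples of the batch interval.
    val windowedCounts = lines.flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKeyAndWindow(_ + _, Seconds(30), Seconds(10))
      .repartition(4)

    windowedCounts.print()
    ssc.start()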
diff --git a/docs/tuning.md b/docs/tuning.md
index 28d88a2659..f491ae9b95 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -175,7 +175,7 @@ To further tune garbage collection, we first need to understand some basic infor
* Java Heap space is divided in to two regions Young and Old. The Young generation is meant to hold short-lived objects
while the Old generation is intended for objects with longer lifetimes.
-* The Young generation is further divided into three regions [Eden, Survivor1, Survivor2].
+* The Young generation is further divided into three regions \[Eden, Survivor1, Survivor2\].
* A simplified description of the garbage collection procedure: When Eden is full, a minor GC is run on Eden and objects
that are alive from Eden and Survivor1 are copied to Survivor2. The Survivor regions are swapped. If an object is old
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 65868b76b9..79848380c0 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -73,7 +73,7 @@ def parse_args():
parser.add_option("-v", "--spark-version", default="0.8.0",
help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
parser.add_option("--spark-git-repo",
- default="https://github.com/mesos/spark",
+ default="https://github.com/apache/incubator-spark",
help="Github repo from which to checkout supplied commit hash")
parser.add_option("--hadoop-major-version", default="1",
help="Major version of Hadoop (default: 1)")
diff --git a/examples/pom.xml b/examples/pom.xml
index c6c9def5be..a10dee7847 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -32,13 +32,20 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
+
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -81,9 +88,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <artifactId>kafka_2.9.2</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
@@ -137,6 +153,14 @@
<groupId>org.apache.cassandra.deps</groupId>
<artifactId>avro</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
new file mode 100644
index 0000000000..9a8e4209ed
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples;
+
+import java.util.Map;
+import java.util.HashMap;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import scala.Tuple2;
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: JavaKafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <zkQuorum> is a list of one or more zookeeper servers that make up the quorum
+ * <group> is the name of kafka consumer group
+ * <topics> is a list of one or more kafka topics to consume from
+ * <numThreads> is the number of threads the kafka consumer should use
+ *
+ * Example:
+ * `./run-example org.apache.spark.streaming.examples.JavaKafkaWordCount local[2] zoo01,zoo02,
+ * zoo03 my-consumer-group topic1,topic2 1`
+ */
+
+public class JavaKafkaWordCount {
+ public static void main(String[] args) {
+ if (args.length < 5) {
+ System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>");
+ System.exit(1);
+ }
+
+ // Create the context with a 2 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+
+ int numThreads = Integer.parseInt(args[4]);
+ Map<String, Integer> topicMap = new HashMap<String, Integer>();
+ String[] topics = args[3].split(",");
+ for (String topic: topics) {
+ topicMap.put(topic, numThreads);
+ }
+
+ JavaPairDStream<String, String> messages = ssc.kafkaStream(args[1], args[2], topicMap);
+
+ JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(Tuple2<String, String> tuple2) throws Exception {
+ return tuple2._2();
+ }
+ });
+
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split(" "));
+ }
+ });
+
+ JavaPairDStream<String, Integer> wordCounts = words.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) throws Exception {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.print();
+ ssc.start();
+ }
+}
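For comparison, the same pipeline takes only a few lines in Scala. A minimal sketch, assuming the kafkaStream(zkQuorum, group, topicMap) overload updated later in this change (it returns (key, message) pairs); the master, ZooKeeper address, topic, and TTL value are illustrative.

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    object KafkaWordCountSketch {
      def main(args: Array[String]) {
        System.setProperty("spark.cleaner.ttl", "3600") // streaming requires a cleaner TTL
        val ssc = new StreamingContext("local[2]", "KafkaWordCountSketch", Seconds(2))
        // Take only the message payload from each (key, message) pair.
        val lines = ssc.kafkaStream("zoo01:2181", "my-consumer-group", Map("topic1" -> 1)).map(_._2)
        val wordCounts = lines.flatMap(_.split(" ")).map((_, 1)).reduceByKey(_ + _)
        wordCounts.print()
        ssc.start()
      }
    }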
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 868ff81f67..529709c2f9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -22,12 +22,19 @@ import org.apache.spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+ System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]")
System.exit(1)
}
- val sc = new SparkContext(args(0), "Broadcast Test",
+ val bcName = if (args.length > 3) args(3) else "Http"
+ val blockSize = if (args.length > 4) args(4) else "4096"
+
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
+ System.setProperty("spark.broadcast.blockSize", blockSize)
+
+ val sc = new SparkContext(args(0), "Broadcast Test 2",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@@ -36,13 +43,15 @@ object BroadcastTest {
arr1(i) = i
}
- for (i <- 0 until 2) {
+ for (i <- 0 until 3) {
println("Iteration " + i)
println("===========")
+ val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
+ println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
System.exit(0)
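The example above selects the broadcast implementation and block size through system properties. A minimal sketch of the same configuration outside the example, assuming the Http/Torrent factory class names used above; the block size and data are illustrative.

    import org.apache.spark.SparkContext

    // Both properties must be set before the SparkContext is created.
    System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
    System.setProperty("spark.broadcast.blockSize", "4096")

    val sc = new SparkContext("local[2]", "Broadcast sketch")
    val data = sc.broadcast((1 to 1000000).toArray)
    sc.parallelize(1 to 10, 2).foreach(_ => println(data.value.length))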
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index 646682878f..86dd9ca1b3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -21,6 +21,7 @@ import java.util.Random
import scala.math.exp
import org.apache.spark.util.Vector
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.InputFormatInfo
/**
@@ -51,7 +52,7 @@ object SparkHdfsLR {
System.exit(1)
}
val inputPath = args(1)
- val conf = SparkEnv.get.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val sc = new SparkContext(args(0), "SparkHdfsLR",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(),
InputFormatInfo.computePreferredLocations(
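The Hadoop configuration is now obtained through SparkHadoopUtil rather than SparkEnv. A minimal sketch of the new entry point; the filesystem address is a hypothetical illustration.

    import org.apache.spark.deploy.SparkHadoopUtil

    val hadoopConf = SparkHadoopUtil.get.newConfiguration()
    hadoopConf.set("fs.default.name", "hdfs://namenode:9000") // hypothetical cluster address
    println(hadoopConf.get("fs.default.name"))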
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
index f7bf75b4e5..bc2db39c12 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala
@@ -21,8 +21,6 @@ import java.util.Random
import org.apache.spark.SparkContext
import org.apache.spark.util.Vector
import org.apache.spark.SparkContext._
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
/**
* K-means clustering.
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 5a2bc9b0d0..a689e5a360 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -38,6 +38,6 @@ object SparkPi {
if (x*x + y*y < 1) 1 else 0
}.reduce(_ + _)
println("Pi is roughly " + 4.0 * count / n)
- System.exit(0)
+ spark.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
index 12f939d5a7..570ba4c81a 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
@@ -18,13 +18,11 @@
package org.apache.spark.streaming.examples
import java.util.Properties
-import kafka.message.Message
-import kafka.producer.SyncProducerConfig
+
import kafka.producer._
-import org.apache.spark.SparkContext
+
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.RawTextHelper._
/**
@@ -54,9 +52,10 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
+ val lines = ssc.kafkaStream(zkQuorum, group, topicpMap).map(_._2)
val words = lines.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ val wordCounts = words.map(x => (x, 1L))
+ .reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
ssc.start()
@@ -68,15 +67,16 @@ object KafkaWordCountProducer {
def main(args: Array[String]) {
if (args.length < 2) {
- System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.err.println("Usage: KafkaWordCountProducer <metadataBrokerList> <topic> " +
+ "<messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
- val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
+ val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
// Zookeper connection properties
val props = new Properties()
- props.put("zk.connect", zkQuorum)
+ props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
@@ -85,11 +85,13 @@ object KafkaWordCountProducer {
// Send some messages
while(true) {
val messages = (1 to messagesPerSec.toInt).map { messageNum =>
- (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
+ .mkString(" ")
+
+ new KeyedMessage[String, String](topic, str)
}.toArray
- println(messages.mkString(","))
- val data = new ProducerData[String, String](topic, messages)
- producer.send(data)
+
+ producer.send(messages: _*)
Thread.sleep(100)
}
}
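The producer above now targets the Kafka 0.8 API: a broker list replaces the ZooKeeper quorum, and KeyedMessage replaces ProducerData. A minimal sketch of sending one message with that API; the broker address and topic are illustrative.

    import java.util.Properties
    import kafka.producer.{KeyedMessage, Producer, ProducerConfig}

    val props = new Properties()
    props.put("metadata.broker.list", "localhost:9092")
    props.put("serializer.class", "kafka.serializer.StringEncoder")

    val producer = new Producer[String, String](new ProducerConfig(props))
    producer.send(new KeyedMessage[String, String]("topic1", "hello kafka"))
    producer.close()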
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
new file mode 100644
index 0000000000..af698a01d5
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples
+
+import org.apache.spark.streaming.{ Seconds, StreamingContext }
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.MQTTReceiver
+import org.apache.spark.storage.StorageLevel
+
+import org.eclipse.paho.client.mqttv3.MqttClient
+import org.eclipse.paho.client.mqttv3.MqttClientPersistence
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.eclipse.paho.client.mqttv3.MqttMessage
+import org.eclipse.paho.client.mqttv3.MqttTopic
+
+/**
+ * A simple MQTT publisher for demonstration purposes; it repeatedly publishes the
+ * space-separated message "hello mqtt demo for spark streaming".
+ */
+object MQTTPublisher {
+
+ var client: MqttClient = _
+
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
+ System.exit(1)
+ }
+
+ val Seq(brokerUrl, topic) = args.toSeq
+
+ try {
+ val persistence: MqttClientPersistence = new MqttDefaultFilePersistence("/tmp")
+ client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
+ } catch {
+ case e: MqttException => println("Exception Caught: " + e)
+ }
+
+ client.connect()
+
+ val msgtopic: MqttTopic = client.getTopic(topic);
+ val msg: String = "hello mqtt demo for spark streaming"
+
+ while (true) {
+ val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes())
+ msgtopic.publish(message);
+ println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
+ }
+ client.disconnect()
+ }
+}
+
+/**
+ * A sample word count over an MQTT stream.
+ *
+ * To use this example, an MQTT message broker/server is required.
+ * Mosquitto (http://mosquitto.org/) is an open source MQTT broker.
+ * On Ubuntu, mosquitto can be installed using the command `$ sudo apt-get install mosquitto`.
+ * The Eclipse Paho project provides a Java library for MQTT clients: http://www.eclipse.org/paho/
+ * Example Java code for an MQTT publisher and subscriber can be found at https://bitbucket.org/mkjinesh/mqttclient
+ * Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
+ * In local mode, <master> should be 'local[n]' with n > 1.
+ * <MqttbrokerUrl> and <topic> describe where the MQTT publisher is running.
+ *
+ * To run this example locally, you may run the publisher as
+ *    `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
+ * and run the example as
+ *    `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
+ */
+object MQTTWordCount {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
+ " In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+
+ val Seq(master, brokerUrl, topic) = args.toSeq
+
+ val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
+ Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+ val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)
+
+ val words = lines.flatMap(x => x.toString.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
index 884d6d6f34..de70c50473 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
@@ -17,17 +17,19 @@
package org.apache.spark.streaming.examples.clickstream
-import java.net.{InetAddress,ServerSocket,Socket,SocketException}
-import java.io.{InputStreamReader, BufferedReader, PrintWriter}
+import java.net.ServerSocket
+import java.io.PrintWriter
import util.Random
/** Represents a page view on a website with associated dimension data.*/
-class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) {
+class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int)
+ extends Serializable {
override def toString() : String = {
"%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID)
}
}
-object PageView {
+
+object PageView extends Serializable {
def fromString(in : String) : PageView = {
val parts = in.split("\t")
new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt)
@@ -39,6 +41,9 @@ object PageView {
* This should be used in tandem with PageViewStream.scala. Example:
* $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10
* $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444
+ *
+ * When running this, you may want to set the root logging level to ERROR in
+ * conf/log4j.properties to reduce the verbosity of the output.
* */
object PageViewGenerator {
val pages = Map("http://foo.com/" -> .7,
diff --git a/pom.xml b/pom.xml
index 9278856e5a..72b9549cfa 100644
--- a/pom.xml
+++ b/pom.xml
@@ -21,7 +21,7 @@
<parent>
<groupId>org.apache</groupId>
<artifactId>apache</artifactId>
- <version>11</version>
+ <version>13</version>
</parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
@@ -61,6 +61,29 @@
<maven>3.0.0</maven>
</prerequisites>
+ <mailingLists>
+ <mailingList>
+ <name>Dev Mailing List</name>
+ <post>dev@spark.incubator.apache.org</post>
+ <subscribe>dev-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>dev-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+
+ <mailingList>
+ <name>User Mailing List</name>
+ <post>user@spark.incubator.apache.org</post>
+ <subscribe>user-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>user-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+
+ <mailingList>
+ <name>Commits Mailing List</name>
+ <post>commits@spark.incubator.apache.org</post>
+ <subscribe>commits-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>commits-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+ </mailingLists>
+
<modules>
<module>core</module>
<module>bagel</module>
@@ -126,6 +149,17 @@
<enabled>false</enabled>
</snapshots>
</repository>
+ <repository>
+ <id>mqtt-repo</id>
+ <name>MQTT Repository</name>
+ <url>https://repo.eclipse.org/content/repositories/paho-releases/</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
</repositories>
<pluginRepositories>
<pluginRepository>
@@ -227,13 +261,36 @@
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
+ <artifactId>akka-actor_${scala-short.version}</artifactId>
+ <version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>com.typesafe.akka</groupId>
<artifactId>akka-remote_${scala-short.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-slf4j_${scala-short.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
@@ -347,6 +404,17 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.apache.zookeeper</groupId>
+ <artifactId>zookeeper</artifactId>
+ <version>3.4.5</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<version>${hadoop.version}</version>
@@ -361,19 +429,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -397,19 +457,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -428,19 +480,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -459,19 +503,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 43db97f680..b71e1b3a56 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -60,6 +60,8 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
+ lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -74,8 +76,11 @@ object SparkBuild extends Build {
// Conditionally include the yarn sub-project
lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
- lazy val allProjects = Seq[ProjectReference](
- core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef
+
+ // Everything except assembly, tools and examples belong to packageProjects
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef
+
+ lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj)
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.apache.spark",
@@ -100,7 +105,13 @@ object SparkBuild extends Build {
// also check the local Maven repository ~/.m2
resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))),
- // For Sonatype publishing
+ // Shared between both core and streaming.
+ resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
+
+ // Shared between both examples and streaming.
+ resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"),
+
+ // For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@@ -212,7 +223,7 @@ object SparkBuild extends Build {
"org.apache.mesos" % "mesos" % "0.13.0",
"net.java.dev.jets3t" % "jets3t" % "0.7.1",
"org.apache.derby" % "derby" % "10.4.2.0" % "test",
- "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm),
+ "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
"org.apache.avro" % "avro" % "1.7.4",
"org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty),
"com.codahale.metrics" % "metrics-core" % "3.0.0",
@@ -272,8 +283,19 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
+ resolvers ++= Seq(
+ "Akka Repository" at "http://repo.akka.io/releases/",
+ "Apache repo" at "https://repository.apache.org/content/repositories/releases"
+ ),
+
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
+ "com.sksamuel.kafka" %% "kafka" % "0.8.0-beta1"
+ exclude("com.sun.jdmk", "jmxtools")
+ exclude("com.sun.jmx", "jmxri")
+ exclude("net.sf.jopt-simple", "jopt-simple")
+ excludeAll(excludeNetty),
+ "org.eclipse.paho" % "mqtt-client" % "0.4.0",
"com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" %% "akka-zeromq" % "2.2.3" excludeAll(excludeNetty)
@@ -300,7 +322,9 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
+ assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
+ jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
@@ -308,6 +332,7 @@ object SparkBuild extends Build {
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
+ case "log4j.properties" => MergeStrategy.discard
case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
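The build changes above pair transitive-dependency exclusions with assembly merge rules. A minimal sbt sketch of that pattern in isolation, assuming the sbt-assembly keys already used in SparkBuild.scala; the dependency shown is the Kafka artifact added to the streaming project.

    // Exclude a conflicting transitive dependency and resolve duplicate files at assembly time.
    val excludeNetty = ExclusionRule(organization = "org.jboss.netty")

    libraryDependencies += "com.sksamuel.kafka" %% "kafka" % "0.8.0-beta1" excludeAll(excludeNetty)

    mergeStrategy in assembly := {
      case "log4j.properties" => MergeStrategy.discard  // keep a single logging config
      case "reference.conf"   => MergeStrategy.concat   // merge Typesafe config fragments
      case _                  => MergeStrategy.first
    }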
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index d367f91967..da3d96689a 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
>>> a.value
13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+... b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@@ -139,9 +146,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
+ def add(self, term):
+ """Adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
- self._value = self.accum_param.addInPlace(self._value, term)
+ self.add(term)
return self
def __str__(self):
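The Python change above adds an explicit add() alongside the += operator. For reference, a minimal sketch of the equivalent accumulator pattern in the Scala API (not part of this change); the master and data are illustrative.

    import org.apache.spark.SparkContext

    val sc = new SparkContext("local", "accumulator sketch")
    val acc = sc.accumulator(0)
    sc.parallelize(Seq(1, 2, 3)).foreach(x => acc += x)
    println(acc.value) // 6 once the action has finished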
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 597110321a..a7ca8bc888 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -49,6 +49,7 @@ class SparkContext(object):
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
+
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
"""
@@ -66,19 +67,18 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+
+ >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
- with SparkContext._lock:
- if SparkContext._active_spark_context:
- raise ValueError("Cannot run multiple SparkContexts at once")
- else:
- SparkContext._active_spark_context = self
- if not SparkContext._gateway:
- SparkContext._gateway = launch_gateway()
- SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
- SparkContext._takePartition = \
- SparkContext._jvm.PythonRDD.takePartition
+ SparkContext._ensure_initialized(self)
+
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -119,6 +119,32 @@ class SparkContext(object):
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ @classmethod
+ def _ensure_initialized(cls, instance=None):
+ with SparkContext._lock:
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeIteratorToPickleFile = \
+ SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._takePartition = \
+ SparkContext._jvm.PythonRDD.takePartition
+
+ if instance:
+ if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = instance
+
+ @classmethod
+ def setSystemProperty(cls, key, value):
+ """
+ Set a system property, such as spark.executor.memory. This must be
+ invoked before instantiating SparkContext.
+ """
+ SparkContext._ensure_initialized()
+ SparkContext._jvm.java.lang.System.setProperty(key, value)
+
@property
def defaultParallelism(self):
"""
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 988b624feb..43e504c290 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -675,6 +675,20 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
Result(true, shouldReplay)
}
+ def addAllClasspath(args: Seq[String]): Unit = {
+ var added = false
+ var totalClasspath = ""
+ for (arg <- args) {
+ val f = File(arg).normalize
+ if (f.exists) {
+ added = true
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ }
+ }
+ if (added) replay()
+ }
+
def addClasspath(arg: String): Unit = {
val f = File(arg).normalize
if (f.exists) {
@@ -915,10 +929,10 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}
def createSparkContext(): SparkContext = {
- val uri = System.getenv("SPARK_EXECUTOR_URI")
- if (uri != null) {
- System.setProperty("spark.executor.uri", uri)
- }
+ val uri = System.getenv("SPARK_EXECUTOR_URI")
+ if (uri != null) {
+ System.setProperty("spark.executor.uri", uri)
+ }
val master = this.master match {
case Some(m) => m
case None => {
@@ -926,9 +940,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
if (prop != null) prop else "local"
}
}
- val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath)
- sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
- echo("Created spark context..")
+ val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
+ .getOrElse(new Array[String](0))
+ .map(new java.io.File(_).getAbsolutePath)
+ try {
+ sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ } catch {
+ case e: Exception =>
+ e.printStackTrace()
+ echo("Failed to create SparkContext, exiting...")
+ sys.exit(1)
+ }
sparkContext
}
diff --git a/spark-class b/spark-class
index 5305b3d025..359db3d984 100755
--- a/spark-class
+++ b/spark-class
@@ -95,10 +95,17 @@ export JAVA_OPTS
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
- ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
- if [[ $? != 0 ]]; then
- echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2
- echo "You need to build Spark with sbt/sbt assembly before running this program" >&2
+ num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l)
+ jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar")
+ if [ "$num_jars" -eq "0" ]; then
+ echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2
+ echo "You need to build Spark with 'sbt/sbt assembly' before running this program." >&2
+ exit 1
+ fi
+ if [ "$num_jars" -gt "1" ]; then
+ echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2
+ echo "$jars_list"
+ echo "Please remove all but one jar."
exit 1
fi
fi
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
deleted file mode 100644
index 65f79925a4..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
+++ /dev/null
Binary files differ
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
deleted file mode 100644
index 29f45f4adb..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
+++ /dev/null
@@ -1 +0,0 @@
-18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
deleted file mode 100644
index e3bd62bac0..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
+++ /dev/null
@@ -1 +0,0 @@
-06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
deleted file mode 100644
index 082d35726a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
+++ /dev/null
@@ -1,9 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
- <modelVersion>4.0.0</modelVersion>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version>
- <description>POM was created from install:install-file</description>
-</project>
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
deleted file mode 100644
index 92c4132b5b..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
+++ /dev/null
@@ -1 +0,0 @@
-7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
deleted file mode 100644
index 8a1d8a097a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
+++ /dev/null
@@ -1 +0,0 @@
-d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
deleted file mode 100644
index 720cd51c2f..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<metadata>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <versioning>
- <release>0.7.2-spark</release>
- <versions>
- <version>0.7.2-spark</version>
- </versions>
- <lastUpdated>20130121015225</lastUpdated>
- </versioning>
-</metadata>
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
deleted file mode 100644
index a4ce5dc9e8..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
+++ /dev/null
@@ -1 +0,0 @@
-e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
deleted file mode 100644
index b869eaf2a6..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
+++ /dev/null
@@ -1 +0,0 @@
-2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 3f2033f34a..fb15681e25 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -32,10 +32,16 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
@@ -56,9 +62,22 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <artifactId>kafka_2.9.2</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>net.sf.jopt-simple</groupId>
+ <artifactId>jopt-simple</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.apache.flume</groupId>
@@ -69,17 +88,22 @@
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
- <groupId>com.github.sgroschupf</groupId>
- <artifactId>zkclient</artifactId>
- <version>0.1</version>
- </dependency>
- <dependency>
<groupId>org.twitter4j</groupId>
<artifactId>twitter4j-stream</artifactId>
<version>3.0.3</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -89,6 +113,12 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-zeromq_${scala-short.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
@@ -114,6 +144,11 @@
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.eclipse.paho</groupId>
+ <artifactId>mqtt-client</artifactId>
+ <version>0.4.0</version>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala-short.version}/classes</outputDirectory>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 2d8f072624..bb9febad38 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.Logging
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.MetadataCleaner
private[streaming]
@@ -40,6 +41,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
+ val delaySeconds = MetadataCleaner.getDelaySeconds
def validate() {
assert(master != null, "Checkpoint.master is null")
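The checkpoint now records the metadata-cleaner delay, so a recovered context keeps the same TTL; the StreamingContext change further below refuses to start when spark.cleaner.ttl is unset. A minimal sketch of setting it up front; the TTL value is illustrative.

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    System.setProperty("spark.cleaner.ttl", "3600") // seconds; required before creating the context
    val ssc = new StreamingContext("local[2]", "ttl sketch", Seconds(1))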
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index cd404fd408..329d2b5835 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.conf.Configuration
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
- * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]]
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.rdd.RDD]]
* for more details on RDDs). DStreams can either be created from live data (such as, data from
* HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations
* such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
@@ -439,6 +439,13 @@ abstract class DStream[T: ClassTag] (
*/
def glom(): DStream[Array[T]] = new GlommedDStream(this)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
/**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
@@ -480,7 +487,7 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: RDD[T] => Unit) {
this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
@@ -488,7 +495,7 @@ abstract class DStream[T: ClassTag] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: (RDD[T], Time) => Unit) {
ssc.registerOutputStream(new ForEachDStream(this, context.sparkContext.clean(foreachFunc)))
@@ -496,18 +503,52 @@ abstract class DStream[T: ClassTag] (
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transform((r: RDD[T], t: Time) => transformFunc(r))
+ transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
}
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- new TransformedDStream(this, context.sparkContext.clean(transformFunc))
+ //new TransformedDStream(this, context.sparkContext.clean(transformFunc))
+ val cleanedF = context.sparkContext.clean(transformFunc)
+ val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ assert(rdds.length == 1)
+ cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
+ }
+ new TransformedDStream[U](Seq(this), realTransformFunc)
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U: ClassTag, V: ClassTag](
+ other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
+ ): DStream[V] = {
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
+ transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U: ClassTag, V: ClassTag](
+ other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
+ ): DStream[V] = {
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
+ val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ assert(rdds.length == 2)
+ val rdd1 = rdds(0).asInstanceOf[RDD[T]]
+ val rdd2 = rdds(1).asInstanceOf[RDD[U]]
+ cleanedF(rdd1, rdd2, time)
+ }
+ new TransformedDStream[V](Seq(this, other), realTransformFunc)
}
/**
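The repartition and transformWith operations added above act on the per-batch RDDs of one or two streams. A minimal sketch that rebalances one stream and joins it with another; the socket sources, hosts, ports, and TTL value are illustrative.

    import org.apache.spark.SparkContext._
    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    System.setProperty("spark.cleaner.ttl", "3600") // required by streaming
    val ssc = new StreamingContext("local[2]", "transformWith sketch", Seconds(1))
    val left  = ssc.socketTextStream("localhost", 9999).map(line => (line, 1))
    val right = ssc.socketTextStream("localhost", 9998).map(line => (line, "tag"))

    val rebalanced = left.repartition(4) // each batch RDD gets exactly 4 partitions
    val joined = rebalanced.transformWith(right,
      (l: RDD[(String, Int)], r: RDD[(String, String)]) => l.join(r))
    joined.print()
    ssc.start()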
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
index b761646dff..66fe6e7870 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
@@ -29,10 +29,12 @@ import scala.collection.mutable.Queue
import akka.actor._
import akka.pattern.ask
import scala.concurrent.duration._
+import akka.dispatch._
+import org.apache.spark.storage.BlockId
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
+private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/**
@@ -47,7 +49,7 @@ class NetworkInputTracker(
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
- val receivedBlockIds = new HashMap[Int, Queue[String]]
+ val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
val timeout = 5000.milliseconds
var currentTime: Time = null
@@ -66,9 +68,9 @@ class NetworkInputTracker(
}
/** Return all the blocks received from a receiver. */
- def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+ def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
val queue = receivedBlockIds.synchronized {
- receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+ receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
}
val result = queue.synchronized {
queue.dequeueAll(x => true)
@@ -91,7 +93,7 @@ class NetworkInputTracker(
case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
- receivedBlockIds += ((streamId, new Queue[String]))
+ receivedBlockIds += ((streamId, new Queue[BlockId]))
}
receivedBlockIds(streamId)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
index f021e29619..ea5c165691 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
@@ -18,9 +18,7 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.streaming.dstream.{ReducedWindowedDStream, StateDStream}
-import org.apache.spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream}
-import org.apache.spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream}
+import org.apache.spark.streaming.dstream._
import org.apache.spark.{Partitioner, HashPartitioner}
import org.apache.spark.SparkContext._
@@ -35,6 +33,7 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.conf.Configuration
+import scala.Some
class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)])
extends Serializable {
@@ -360,7 +359,7 @@ extends Serializable {
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then
@@ -399,11 +398,18 @@ extends Serializable {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
}
-
+ /**
+ * Return a new DStream by applying a map function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = {
new MapValuedDStream[K, V, U](self, mapValuesFunc)
}
+ /**
+ * Return a new DStream by applying a flatmap function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
def flatMapValues[U: ClassTag](
flatMapValuesFunc: V => TraversableOnce[U]
): DStream[(K, U)] = {
@@ -411,9 +417,8 @@ extends Serializable {
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number
* of partitions.
*/
def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
@@ -421,56 +426,132 @@ extends Serializable {
}
/**
- * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. Partitioner is used to partition each generated RDD.
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to partition the generated RDDs.
*/
def cogroup[W: ClassTag](
other: DStream[(K, W)],
partitioner: Partitioner
): DStream[(K, (Seq[V], Seq[W]))] = {
-
- val cgd = new CoGroupedDStream[K](
- Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]),
- partitioner
- )
- val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
- classTag[K],
- ClassTags.seqSeqClassTag
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner)
)
- pdfs.mapValues {
- case Seq(vs, ws) =>
- (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
- }
}
/**
- * Join `this` DStream with `other` DStream. HashPartitioner is used
- * to partition each generated RDD into default number of partitions.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
*/
def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
join[W](other, defaultPartitioner())
}
/**
- * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
- * be generated by joining RDDs from `this` and other DStream. Uses the given
- * Partitioner to partition each generated RDD.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def join[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (V, W))] = {
+ join[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
*/
def join[W: ClassTag](
other: DStream[(K, W)],
partitioner: Partitioner
): DStream[(K, (V, W))] = {
- this.cogroup(other, partitioner)
- .flatMapValues{
- case (vs, ws) =>
- for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
- }
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.join(rdd2, partitioner)
+ )
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def leftOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = {
+ leftOuterJoin[W](other, defaultPartitioner())
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def leftOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ numPartitions: Int
+ ): DStream[(K, (V, Option[W]))] = {
+ leftOuterJoin[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def leftOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (V, Option[W]))] = {
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.leftOuterJoin(rdd2, partitioner)
+ )
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def rightOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = {
+ rightOuterJoin[W](other, defaultPartitioner())
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def rightOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ numPartitions: Int
+ ): DStream[(K, (Option[V], W))] = {
+ rightOuterJoin[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def rightOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (Option[V], W))] = {
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.rightOuterJoin(rdd2, partitioner)
+ )
}
/**
- * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
- * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval
+ * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
*/
def saveAsHadoopFiles[F <: OutputFormat[K, V]](
prefix: String,
@@ -480,8 +561,8 @@ extends Serializable {
}
/**
- * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
- * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval
+ * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
*/
def saveAsHadoopFiles(
prefix: String,
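The join family above performs a per-batch join between two key-value streams. A minimal sketch using leftOuterJoin and cogroup; the socket sources and key extraction are illustrative.

    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    System.setProperty("spark.cleaner.ttl", "3600") // required by streaming
    val ssc = new StreamingContext("local[2]", "join sketch", Seconds(1))
    val clicks   = ssc.socketTextStream("localhost", 9999).map(user => (user, 1))
    val profiles = ssc.socketTextStream("localhost", 9998).map(line => (line.split(",")(0), line))

    val enriched = clicks.leftOuterJoin(profiles) // DStream[(String, (Int, Option[String]))]
    val grouped  = clicks.cogroup(profiles, 4)    // explicit number of partitions
    enriched.print()
    ssc.start()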
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index c722aa15ab..d2c4fdee65 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -102,6 +102,10 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
+ if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds < 0) {
+ MetadataCleaner.setDelaySeconds(cp_.delaySeconds)
+ }
+
if (MetadataCleaner.getDelaySeconds < 0) {
throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
+ "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
@@ -254,10 +258,14 @@ class StreamingContext private (
groupId: String,
topics: Map[String, Int],
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
- ): DStream[String] = {
+ ): DStream[(String, String)] = {
val kafkaParams = Map[String, String](
- "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
- kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
+ "zookeeper.connect" -> zkQuorum, "group.id" -> groupId,
+ "zookeeper.connection.timeout.ms" -> "10000")
+ kafkaStream[String, String, kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](
+ kafkaParams,
+ topics,
+ storageLevel)
}
/**
@@ -268,12 +276,16 @@ class StreamingContext private (
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
*/
- def kafkaStream[T: ClassTag, D <: kafka.serializer.Decoder[_]: Manifest](
+ def kafkaStream[
+ K: ClassTag,
+ V: ClassTag,
+ U <: kafka.serializer.Decoder[_]: Manifest,
+ T <: kafka.serializer.Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ): DStream[T] = {
- val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
+ ): DStream[(K, V)] = {
+ val inputStream = new KafkaInputDStream[K, V, U, T](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
@@ -452,14 +464,40 @@ class StreamingContext private (
inputStream
}
+/**
+ * Create an input stream that receives messages pushed by an MQTT publisher.
+ * @param brokerUrl URL of the remote MQTT publisher
+ * @param topic topic name to subscribe to
+ * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_ONLY_SER_2.
+ */
+
+ def mqttStream(
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = {
+ val inputStream = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel)
+ registerInputStream(inputStream)
+ inputStream
+ }
/**
- * Create a unified DStream from multiple DStreams of the same type and same interval
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
*/
def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = {
new UnionDStream[T](streams.toArray)
}
/**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams.
+ */
+ def transform[T: ClassTag](
+ dstreams: Seq[DStream[_]],
+ transformFunc: (Seq[RDD[_]], Time) => RDD[T]
+ ): DStream[T] = {
+ new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
+ }
+
+ /**
* Register an input stream that will be started (InputDStream.start() called) to get the
* input data.
*/
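[Editor's note] The new StreamingContext helpers above (mqttStream, multi-stream union, multi-stream transform) can be exercised as follows; this is a rough sketch in which the socket sources and the trivial per-batch union are assumptions made purely for illustration:

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext, Time}

val ssc = new StreamingContext("local[2]", "MultiStreamSketch", Seconds(1))
val s1 = ssc.socketTextStream("localhost", 9999)
val s2 = ssc.socketTextStream("localhost", 9998)

// One DStream carrying the elements of both inputs in every batch.
val merged = ssc.union(Seq(s1, s2))

// Arbitrary per-batch RDD logic over several DStreams at once; here the two batch
// RDDs are simply unioned again to keep the sketch trivial.
val combined = ssc.transform(Seq(s1, s2), (rdds: Seq[RDD[_]], time: Time) =>
  rdds(0).asInstanceOf[RDD[String]].union(rdds(1).asInstanceOf[RDD[String]]))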
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index 0d54d78ed3..d29033df32 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
- * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]]
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.rdd.RDD]]
* for more details on RDDs). DStreams can either be created from live data (such as, data from
* HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations
* such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
@@ -96,6 +96,12 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T]
*/
def union(that: JavaDStream[T]): JavaDStream[T] =
dstream.union(that.dstream)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaDStream[T] = dstream.repartition(numPartitions)
}
object JavaDStream {
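[Editor's note] The Java repartition wrapper added above delegates to DStream.repartition on the Scala side. A tiny sketch of that underlying call, with the source and partition count chosen arbitrarily for illustration:

import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext("local[2]", "RepartitionSketch", Seconds(1))
val lines = ssc.socketTextStream("localhost", 9999)

// Every RDD produced by `widened` has exactly 8 partitions, regardless of how many
// blocks the receiver generated during that batch interval.
val widened = lines.repartition(8)
widened.print()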
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 4508e48590..64f38ce1c0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -25,7 +25,8 @@ import scala.reflect.ClassTag
import org.apache.spark.streaming._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, _}
import java.util
import org.apache.spark.rdd.RDD
import JavaDStream._
@@ -121,10 +122,12 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* this DStream. Applying glom() to an RDD coalesces all elements within each partition into
* an array.
*/
- def glom(): JavaDStream[JList[T]] =
+ def glom(): JavaDStream[JList[T]] = {
new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+ }
+
- /** Return the StreamingContext associated with this DStream */
+ /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */
def context(): StreamingContext = dstream.context()
/** Return a new DStream by applying a function to all elements of this DStream. */
@@ -239,7 +242,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: JFunction[R, Void]) {
dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd)))
@@ -247,7 +250,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: JFunction2[R, Time, Void]) {
dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time))
@@ -255,7 +258,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = {
implicit val cm: ClassTag[U] =
@@ -267,7 +270,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = {
implicit val cm: ClassTag[U] =
@@ -279,7 +282,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]):
JavaPairDStream[K2, V2] = {
@@ -294,7 +297,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]):
JavaPairDStream[K2, V2] = {
@@ -308,6 +311,82 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
}
/**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U, W](
+ other: JavaDStream[U],
+ transformFunc: JFunction3[R, JavaRDD[U], Time, JavaRDD[W]]
+ ): JavaDStream[W] = {
+ implicit val cmu: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+ implicit val cmv: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[W] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[U, W](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U, K2, V2](
+ other: JavaDStream[U],
+ transformFunc: JFunction3[R, JavaRDD[U], Time, JavaPairRDD[K2, V2]]
+ ): JavaPairDStream[K2, V2] = {
+ implicit val cmu: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[(K2, V2)] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[U, (K2, V2)](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[K2, V2, W](
+ other: JavaPairDStream[K2, V2],
+ transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaRDD[W]]
+ ): JavaDStream[W] = {
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ implicit val cmw: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[K2, V2, K3, V3](
+ other: JavaPairDStream[K2, V2],
+ transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaPairRDD[K3, V3]]
+ ): JavaPairDStream[K3, V3] = {
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ implicit val cmk3: ClassTag[K3] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K3]]
+ implicit val cmv3: ClassTag[V3] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V3]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[(K3, V3)] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[(K2, V2), (K3, V3)](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
* Enable periodic checkpointing of RDDs of this DStream
* @param interval Time interval after which generated RDD will be checkpointed
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index c80545b530..3ba37bed4d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3}
import org.apache.spark.Partitioner
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -37,8 +37,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PairRDDFunctions
class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
- implicit val kTag: ClassTag[K],
- implicit val vTag: ClassTag[V])
+ implicit val kManifest: ClassTag[K],
+ implicit val vManifest: ClassTag[V])
extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
@@ -60,6 +60,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/** Persist the RDDs of this DStream with the given storage level */
def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions)
+
/** Method that generates a RDD for the given Duration */
def compute(validTime: Time): JavaPairRDD[K, V] = {
dstream.compute(validTime) match {
@@ -149,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* Combine elements of each key in DStream's RDDs using custom function. This is similar to the
- * combineByKey for RDDs. Please refer to combineByKey in [[PairRDDFunctions]] for more
+ * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.rdd.PairRDDFunctions]] for more
* information.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
@@ -414,7 +420,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -429,7 +435,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -437,15 +443,17 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* @param numPartitions Number of partitions of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassTag](
+ def updateStateByKey[S](
updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
numPartitions: Int)
: JavaPairDStream[K, S] = {
+ implicit val cm: ClassTag[S] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]]
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions)
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then
@@ -453,19 +461,30 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassTag](
+ def updateStateByKey[S](
updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
partitioner: Partitioner
): JavaPairDStream[K, S] = {
+ implicit val cm: ClassTag[S] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]]
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
}
+
+ /**
+ * Return a new DStream by applying a map function to the values of each key-value pair in
+ * 'this' DStream without changing the key.
+ */
def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = {
implicit val cm: ClassTag[U] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
dstream.mapValues(f)
}
+ /**
+ * Return a new DStream by applying a flatMap function to the values of each key-value pair in
+ * 'this' DStream without changing the key.
+ */
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
@@ -475,9 +494,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number
* of partitions.
*/
def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = {
@@ -487,21 +505,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. Partitioner is used to partition each generated RDD.
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
*/
- def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
- : JavaPairDStream[K, (JList[V], JList[W])] = {
+ def cogroup[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (JList[V], JList[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ dstream.cogroup(other.dstream, numPartitions)
+ .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
+ }
+
+ /**
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+ */
+ def cogroup[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (JList[V], JList[W])] = {
implicit val cm: ClassTag[W] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.cogroup(other.dstream, partitioner)
- .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
+ .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
}
/**
- * Join `this` DStream with `other` DStream. HashPartitioner is used
- * to partition each generated RDD into default number of partitions.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
*/
def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = {
implicit val cm: ClassTag[W] =
@@ -510,18 +543,112 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
- * be generated by joining RDDs from `this` and other DStream. Uses the given
- * Partitioner to partition each generated RDD.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def join[W](other: JavaPairDStream[K, W], numPartitions: Int): JavaPairDStream[K, (V, W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ dstream.join(other.dstream, numPartitions)
+ }
+
+ /**
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
*/
- def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
- : JavaPairDStream[K, (V, W)] = {
+ def join[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (V, W)] = {
implicit val cm: ClassTag[W] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.join(other.dstream, partitioner)
}
/**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def leftOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def leftOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream, numPartitions)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+ */
+ def leftOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream, partitioner)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def rightOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def rightOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream, numPartitions)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def rightOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream, partitioner)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
@@ -591,14 +718,19 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
}
+ /** Convert to a JavaDStream */
+ def toJavaDStream(): JavaDStream[(K, V)] = {
+ new JavaDStream[(K, V)](dstream)
+ }
+
override val classTag: ClassTag[(K, V)] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
}
object JavaPairDStream {
- implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)])
- :JavaPairDStream[K, V] =
+ implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]): JavaPairDStream[K, V] = {
new JavaPairDStream[K, V](dstream)
+ }
def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = {
implicit val cmk: ClassTag[K] =
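[Editor's note] Several docs above were reworded to describe updateStateByKey as returning a "state" DStream. A brief sketch of the corresponding Scala call, using a running word count as an illustrative assumption:

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

val ssc = new StreamingContext("local[2]", "StateSketch", Seconds(1))
ssc.checkpoint("/tmp/streaming-checkpoint")   // state DStreams need a checkpoint directory

val pairs = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map((_, 1))

// New count = previous count + this batch's occurrences; returning None would drop the key.
val runningCounts = pairs.updateStateByKey[Int](
  (values: Seq[Int], state: Option[Int]) => Some(state.getOrElse(0) + values.sum))
runningCounts.print()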
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 8242af6d5f..ca0c905932 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -19,7 +19,7 @@ package org.apache.spark.streaming.api.java
import java.lang.{Long => JLong, Integer => JInt}
import java.io.InputStream
-import java.util.{Map => JMap}
+import java.util.{Map => JMap, List => JList}
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
@@ -36,7 +36,7 @@ import twitter4j.auth.Authorization
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
-import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaSparkContext, JavaRDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receivers.{ActorReceiver, ReceiverSupervisorStrategy}
@@ -144,7 +144,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
- : JavaDStream[String] = {
+ : JavaPairDStream[String, String] = {
implicit val cmt: ClassTag[String] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@@ -166,7 +166,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
groupId: String,
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[String] = {
+ : JavaPairDStream[String, String] = {
implicit val cmt: ClassTag[String] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@@ -175,25 +175,34 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream that pulls messages from a Kafka Broker.
- * @param typeClass Type of RDD
- * @param decoderClass Type of kafka decoder
+ * @param keyTypeClass Key type of RDD
+ * @param valueTypeClass Value type of RDD
+ * @param keyDecoderClass Type of kafka key decoder
+ * @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration parameters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
- def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
- typeClass: Class[T],
- decoderClass: Class[D],
+ def kafkaStream[K, V, U <: kafka.serializer.Decoder[_], T <: kafka.serializer.Decoder[_]](
+ keyTypeClass: Class[K],
+ valueTypeClass: Class[V],
+ keyDecoderClass: Class[U],
+ valueDecoderClass: Class[T],
kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[T] = {
- implicit val cmt: ClassTag[T] =
- implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
- implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
- ssc.kafkaStream[T, D](
+ : JavaPairDStream[K, V] = {
+ implicit val keyCmt: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val valueCmt: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+
+ implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]]
+ implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]]
+
+ ssc.kafkaStream[K, V, U, T](
kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)
@@ -589,6 +598,77 @@ class JavaStreamingContext(val ssc: StreamingContext) {
}
/**
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
+ */
+ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = {
+ val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream)
+ implicit val cm: ClassTag[T] = first.classTag
+ ssc.union(dstreams)(cm)
+ }
+
+ /**
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
+ */
+ def union[K, V](
+ first: JavaPairDStream[K, V],
+ rest: JList[JavaPairDStream[K, V]]
+ ): JavaPairDStream[K, V] = {
+ val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream)
+ implicit val cm: ClassTag[(K, V)] = first.classTag
+ implicit val kcm: ClassTag[K] = first.kManifest
+ implicit val vcm: ClassTag[V] = first.vManifest
+ new JavaPairDStream[K, V](ssc.union(dstreams)(cm))(kcm, vcm)
+ }
+
+ /**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+ * same as the order of corresponding DStreams in the list. Note that to add a
+ * JavaPairDStream to the list of JavaDStreams, convert it to a JavaDStream using
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream().
+ * In the transform function, convert the JavaRDD corresponding to that JavaDStream to
+ * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD().
+ */
+ def transform[T](
+ dstreams: JList[JavaDStream[_]],
+ transformFunc: JFunction2[JList[JavaRDD[_]], Time, JavaRDD[T]]
+ ): JavaDStream[T] = {
+ implicit val cmt: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
+ val scalaDStreams = dstreams.map(_.dstream).toSeq
+ val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList
+ transformFunc.call(jrdds, time).rdd
+ }
+ ssc.transform(scalaDStreams, scalaTransformFunc)
+ }
+
+ /**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+ * same as the order of corresponding DStreams in the list. Note that to add a
+ * JavaPairDStream to the list of JavaDStreams, convert it to a JavaDStream using
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream().
+ * In the transform function, convert the JavaRDD corresponding to that JavaDStream to
+ * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD().
+ */
+ def transform[K, V](
+ dstreams: JList[JavaDStream[_]],
+ transformFunc: JFunction2[JList[JavaRDD[_]], Time, JavaPairRDD[K, V]]
+ ): JavaPairDStream[K, V] = {
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ val scalaDStreams = dstreams.map(_.dstream).toSeq
+ val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList
+ transformFunc.call(jrdds, time).rdd
+ }
+ ssc.transform(scalaDStreams, scalaTransformFunc)
+ }
+
+ /**
* Sets the context to periodically checkpoint the DStream operations for master
* fault-tolerance. The graph will be checkpointed every batch interval.
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
index 96134868cc..ec0096c85f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
@@ -19,24 +19,21 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
+import org.apache.spark.streaming.StreamingContext
import java.util.Properties
import java.util.concurrent.Executors
import kafka.consumer._
-import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.Decoder
-import kafka.utils.{Utils, ZKGroupTopicDirs}
-import kafka.utils.ZkUtils._
+import kafka.utils.VerifiableProperties
import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient._
import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
import scala.reflect.ClassTag
+
/**
* Input stream that pulls messages from a Kafka Broker.
*
@@ -46,25 +43,32 @@ import scala.reflect.ClassTag
* @param storageLevel RDD storage level.
*/
private[streaming]
-class KafkaInputDStream[T: ClassTag, D <: Decoder[_]: Manifest](
+class KafkaInputDStream[
+ K: ClassTag,
+ V: ClassTag,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ) extends NetworkInputDStream[T](ssc_ ) with Logging {
-
+ ) extends NetworkInputDStream[(K, V)](ssc_) with Logging {
- def getReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
- .asInstanceOf[NetworkReceiver[T]]
+ def getReceiver(): NetworkReceiver[(K, V)] = {
+ new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
+ .asInstanceOf[NetworkReceiver[(K, V)]]
}
}
private[streaming]
-class KafkaReceiver[T: ClassTag, D <: Decoder[_]: Manifest](
- kafkaParams: Map[String, String],
- topics: Map[String, Int],
- storageLevel: StorageLevel
+class KafkaReceiver[
+ K: ClassTag,
+ V: ClassTag,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
@@ -83,27 +87,35 @@ class KafkaReceiver[T: ClassTag, D <: Decoder[_]: Manifest](
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
+ logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id"))
// Kafka connection properties
val props = new Properties()
kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
+ logInfo("Connecting to Zookeeper: " + kafkaParams("zookeeper.connect"))
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig)
- logInfo("Connected to " + kafkaParams("zk.connect"))
+ logInfo("Connected to " + kafkaParams("zookeeper.connect"))
// When autooffset.reset is defined, it is our responsibility to try and whack the
// consumer group zk node.
- if (kafkaParams.contains("autooffset.reset")) {
- tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
+ if (kafkaParams.contains("auto.offset.reset")) {
+ tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id"))
}
+ val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[K]]
+ val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[V]]
+
// Create Threads for each Topic/Message Stream we are listening
- val decoder = manifest[D].runtimeClass.newInstance.asInstanceOf[Decoder[T]]
- val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
+ val topicMessageStreams = consumerConnector.createMessageStreams(
+ topics, keyDecoder, valueDecoder)
+
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
@@ -112,11 +124,12 @@ class KafkaReceiver[T: ClassTag, D <: Decoder[_]: Manifest](
}
// Handles Kafka Messages
- private class MessageHandler[T: ClassTag](stream: KafkaStream[T]) extends Runnable {
+ private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V])
+ extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
for (msgAndMetadata <- stream) {
- blockGenerator += msgAndMetadata.message
+ blockGenerator += (msgAndMetadata.key, msgAndMetadata.message)
}
}
}
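[Editor's note] The receiver above now instantiates the key and value decoders reflectively from the consumer configuration. A standalone sketch of that pattern, assuming Kafka 0.8's StringDecoder is on the classpath; the property values are placeholders:

import java.util.Properties
import kafka.serializer.{Decoder, StringDecoder}
import kafka.utils.VerifiableProperties

val props = new Properties()
props.put("zookeeper.connect", "zkhost:2181")
props.put("group.id", "my-group")

// Mirrors the receiver: build a Decoder from a VerifiableProperties view of the config.
val keyDecoder: Decoder[String] =
  classOf[StringDecoder]
    .getConstructor(classOf[VerifiableProperties])
    .newInstance(new VerifiableProperties(props))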
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala
new file mode 100644
index 0000000000..ef4a737568
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import org.apache.spark.Logging
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext }
+
+import java.util.Properties
+import java.util.concurrent.Executors
+import java.io.IOException
+
+import org.eclipse.paho.client.mqttv3.MqttCallback
+import org.eclipse.paho.client.mqttv3.MqttClient
+import org.eclipse.paho.client.mqttv3.MqttClientPersistence
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
+import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.eclipse.paho.client.mqttv3.MqttMessage
+import org.eclipse.paho.client.mqttv3.MqttTopic
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
+
+/**
+ * Input stream that subscribes to messages from an MQTT broker.
+ * Uses the Eclipse Paho MQTT client: http://www.eclipse.org/paho/
+ * @param brokerUrl URL of the remote MQTT publisher
+ * @param topic Topic name to subscribe to
+ * @param storageLevel RDD storage level.
+ */
+
+private[streaming]
+class MQTTInputDStream[T: ClassTag](
+ @transient ssc_ : StreamingContext,
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_) with Logging {
+
+ def getReceiver(): NetworkReceiver[T] = {
+ new MQTTReceiver(brokerUrl, topic, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+private[streaming]
+class MQTTReceiver(brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[Any] {
+ lazy protected val blockGenerator = new BlockGenerator(storageLevel)
+
+ def onStop() {
+ blockGenerator.stop()
+ }
+
+ def onStart() {
+
+ blockGenerator.start()
+
+ // Set up persistence for messages
+ var persistence: MqttClientPersistence = new MemoryPersistence()
+
+ // Initializing the MQTT client, specifying brokerUrl, clientID and MqttClientPersistence
+ var client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", persistence)
+
+ // Connect to MqttBroker
+ client.connect()
+
+ // Subscribe to Mqtt topic
+ client.subscribe(topic)
+
+ // Callback is automatically triggered when a new message arrives on the subscribed topic
+ var callback: MqttCallback = new MqttCallback() {
+
+ // Handles Mqtt message
+ override def messageArrived(arg0: String, arg1: MqttMessage) {
+ blockGenerator += new String(arg1.getPayload())
+ }
+
+ override def deliveryComplete(arg0: IMqttDeliveryToken) {
+ }
+
+ override def connectionLost(arg0: Throwable) {
+ logInfo("Connection lost " + arg0)
+ }
+ }
+
+ // Set up callback for MqttClient
+ client.setCallback(callback)
+ }
+}
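[Editor's note] Wiring the new MQTT input stream into a job looks roughly like the sketch below; the broker URL and topic are placeholders, not values from the patch:

import org.apache.spark.streaming.{Seconds, StreamingContext}

val ssc = new StreamingContext("local[2]", "MqttSketch", Seconds(1))

// Each MQTT payload received on the topic becomes one String element of the DStream.
val messages = ssc.mqttStream("tcp://localhost:1883", "sensor/temperature")
messages.count().print()
ssc.start()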
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 394a39fbb0..ab97ee9349 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -32,7 +32,7 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
@@ -70,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
- Some(new BlockRDD[T](ssc.sc, Array[String]()))
+ Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
}
}
}
@@ -78,7 +78,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
private[streaming] sealed trait NetworkReceiverMessage
private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
-private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
+private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
/**
@@ -159,7 +159,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
/**
* Pushes a block (as an ArrayBuffer filled with data) into the block manager.
*/
- def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
actor ! ReportBlock(blockId, metadata)
}
@@ -167,7 +167,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
/**
* Pushes a block (as bytes) into the block manager.
*/
- def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
actor ! ReportBlock(blockId, metadata)
}
@@ -210,7 +210,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
class BlockGenerator(storageLevel: StorageLevel)
extends Serializable with Logging {
- case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null)
+ case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null)
val clock = new SystemClock()
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
@@ -242,7 +242,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
if (newBlockBuffer.size > 0) {
- val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval)
+ val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval)
val newBlock = new Block(blockId, newBlockBuffer)
blocksForPushing.add(newBlock)
}
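[Editor's note] The receiver code above switches from string block ids to the typed StreamBlockId. A minimal sketch of constructing one, assuming the "input-<streamId>-<uniqueId>" naming used elsewhere in this change; the stream id and sequence value are made up for illustration:

import org.apache.spark.storage.StreamBlockId

// Typed replacement for the old "input-" + streamId + "-" + uniqueId string ids.
val blockId = StreamBlockId(3, System.currentTimeMillis())
println(blockId.name)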
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
index a4746f06ad..dea0f26f90 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.StreamingContext
import scala.reflect.ClassTag
@@ -73,7 +73,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
var nextBlockNumber = 0
while (true) {
val buffer = queue.take()
- val blockId = "input-" + streamId + "-" + nextBlockNumber
+ val blockId = StreamBlockId(streamId, nextBlockNumber)
nextBlockNumber += 1
pushBlock(blockId, buffer, null, storageLevel)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index 73e1ddf7a4..aeea060df7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -22,16 +22,22 @@ import org.apache.spark.streaming.{Duration, DStream, Time}
import scala.reflect.ClassTag
private[streaming]
-class TransformedDStream[T: ClassTag, U: ClassTag] (
- parent: DStream[T],
- transformFunc: (RDD[T], Time) => RDD[U]
- ) extends DStream[U](parent.ssc) {
+class TransformedDStream[U: ClassTag] (
+ parents: Seq[DStream[_]],
+ transformFunc: (Seq[RDD[_]], Time) => RDD[U]
+ ) extends DStream[U](parents.head.ssc) {
- override def dependencies = List(parent)
+ require(parents.length > 0, "List of DStreams to transform is empty")
+ require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts")
+ require(parents.map(_.slideDuration).distinct.size == 1,
+ "Some of the DStreams have different slide durations")
- override def slideDuration: Duration = parent.slideDuration
+ override def dependencies = parents.toList
+
+ override def slideDuration: Duration = parents.head.slideDuration
override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(transformFunc(_, validTime))
+ val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq
+ Some(transformFunc(parentRDDs, validTime))
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index ee087a1cf0..fdf5371a89 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -25,7 +25,7 @@ import akka.actor.SupervisorStrategy._
import scala.concurrent.duration._
import scala.reflect.ClassTag
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.dstream.NetworkReceiver
import java.util.concurrent.atomic.AtomicInteger
@@ -160,7 +160,7 @@ private[streaming] class ActorReceiver[T: ClassTag](
protected def pushBlock(iter: Iterator[T]) {
val buffer = new ArrayBuffer[T]
buffer ++= iter
- pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel)
+ pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel)
}
protected def onStart() = {
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 076fb53fa1..daeb99f5b7 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -25,6 +25,7 @@ import com.google.common.io.Files;
import kafka.serializer.StringDecoder;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.spark.streaming.api.java.JavaDStreamLike;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -186,6 +187,39 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testRepartitionMorePartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
+ JavaDStream repartitioned = stream.repartition(4);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(4, rdd.size());
+ Assert.assertEquals(
+ 10, rdd.get(0).size() + rdd.get(1).size() + rdd.get(2).size() + rdd.get(3).size());
+ }
+ }
+
+ @Test
+ public void testRepartitionFewerPartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
+ JavaDStream repartitioned = stream.repartition(2);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(2, rdd.size());
+ Assert.assertEquals(10, rdd.get(0).size() + rdd.get(1).size());
+ }
+ }
+
+ @Test
public void testGlom() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("giants", "dodgers"),
@@ -225,7 +259,7 @@ public class JavaAPISuite implements Serializable {
}
});
JavaTestUtils.attachTestOutputStream(mapped);
- List<List<List<String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -294,8 +328,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(7,8,9));
JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc());
- JavaRDD<Integer> rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3));
- JavaRDD<Integer> rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6));
+ JavaRDD<Integer> rdd1 = ssc.sc().parallelize(Arrays.asList(1, 2, 3));
+ JavaRDD<Integer> rdd2 = ssc.sc().parallelize(Arrays.asList(4, 5, 6));
JavaRDD<Integer> rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9));
LinkedList<JavaRDD<Integer>> rdds = Lists.newLinkedList();
@@ -322,17 +356,19 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(9,10,11));
JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream<Integer> transformed =
- stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
- @Override
- public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
- return in.map(new Function<Integer, Integer>() {
- @Override
- public Integer call(Integer i) throws Exception {
- return i + 2;
- }
- });
- }});
+ JavaDStream<Integer> transformed = stream.transform(
+ new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
+ return in.map(new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer i) throws Exception {
+ return i + 2;
+ }
+ });
+ }
+ });
+
JavaTestUtils.attachTestOutputStream(transformed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -340,6 +376,316 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testVariousTransform() {
+ // tests whether all variations of transform can be called from Java
+
+ List<List<Integer>> inputData = Arrays.asList(Arrays.asList(1));
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+
+ List<List<Tuple2<String, Integer>>> pairInputData =
+ Arrays.asList(Arrays.asList(new Tuple2<String, Integer>("x", 1)));
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1));
+
+ JavaDStream<Integer> transformed1 = stream.transform(
+ new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> transformed2 = stream.transform(
+ new Function2<JavaRDD<Integer>, Time, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaRDD<Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, Integer> transformed3 = stream.transform(
+ new Function<JavaRDD<Integer>, JavaPairRDD<String, Integer>>() {
+ @Override public JavaPairRDD<String, Integer> call(JavaRDD<Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, Integer> transformed4 = stream.transform(
+ new Function2<JavaRDD<Integer>, Time, JavaPairRDD<String, Integer>>() {
+ @Override public JavaPairRDD<String, Integer> call(JavaRDD<Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> pairTransformed1 = pairStream.transform(
+ new Function<JavaPairRDD<String, Integer>, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaPairRDD<String, Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> pairTransformed2 = pairStream.transform(
+ new Function2<JavaPairRDD<String, Integer>, Time, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaPairRDD<String, Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, String> pairTransformed3 = pairStream.transform(
+ new Function<JavaPairRDD<String, Integer>, JavaPairRDD<String, String>>() {
+ @Override public JavaPairRDD<String, String> call(JavaPairRDD<String, Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, String> pairTransformed4 = pairStream.transform(
+ new Function2<JavaPairRDD<String, Integer>, Time, JavaPairRDD<String, String>>() {
+ @Override public JavaPairRDD<String, String> call(JavaPairRDD<String, Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ }
+
+ @Test
+ public void testTransformWith() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(
+ new Tuple2<String, String>("california", "sharks"),
+ new Tuple2<String, String>("new york", "rangers")));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, String>("california", "giants"),
+ new Tuple2<String, String>("new york", "mets")),
+ Arrays.asList(
+ new Tuple2<String, String>("california", "ducks"),
+ new Tuple2<String, String>("new york", "islanders")));
+
+
+ List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("dodgers", "giants")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("yankees", "mets"))),
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("sharks", "ducks")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("rangers", "islanders"))));
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<String, String>> joined = pairStream1.transformWith(
+ pairStream2,
+ new Function3<
+ JavaPairRDD<String, String>,
+ JavaPairRDD<String, String>,
+ Time,
+ JavaPairRDD<String, Tuple2<String, String>>
+ >() {
+ @Override
+ public JavaPairRDD<String, Tuple2<String, String>> call(
+ JavaPairRDD<String, String> rdd1,
+ JavaPairRDD<String, String> rdd2,
+ Time time
+ ) throws Exception {
+ return rdd1.join(rdd2);
+ }
+ }
+ );
+
+ JavaTestUtils.attachTestOutputStream(joined);
+ List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+
+ @Test
+ public void testVariousTransformWith() {
+ // tests whether all variations of transformWith can be called from Java
+
+ List<List<Integer>> inputData1 = Arrays.asList(Arrays.asList(1));
+ List<List<String>> inputData2 = Arrays.asList(Arrays.asList("x"));
+ JavaDStream<Integer> stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 1);
+ JavaDStream<String> stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1);
+
+ List<List<Tuple2<String, Integer>>> pairInputData1 =
+ Arrays.asList(Arrays.asList(new Tuple2<String, Integer>("x", 1)));
+ List<List<Tuple2<Double, Character>>> pairInputData2 =
+ Arrays.asList(Arrays.asList(new Tuple2<Double, Character>(1.0, 'x')));
+ JavaPairDStream<String, Integer> pairStream1 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1));
+ JavaPairDStream<Double, Character> pairStream2 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1));
+
+ JavaDStream<Double> transformed1 = stream1.transformWith(
+ stream2,
+ new Function3<JavaRDD<Integer>, JavaRDD<String>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaRDD<Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> transformed2 = stream1.transformWith(
+ pairStream1,
+ new Function3<JavaRDD<Integer>, JavaPairRDD<String, Integer>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaRDD<Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> transformed3 = stream1.transformWith(
+ stream2,
+ new Function3<JavaRDD<Integer>, JavaRDD<String>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaRDD<Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> transformed4 = stream1.transformWith(
+ pairStream1,
+ new Function3<JavaRDD<Integer>, JavaPairRDD<String, Integer>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaRDD<Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> pairTransformed1 = pairStream1.transformWith(
+ stream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaRDD<String>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaPairRDD<String, Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> pairTransformed2_ = pairStream1.transformWith(
+ pairStream1,
+ new Function3<JavaPairRDD<String, Integer>, JavaPairRDD<String, Integer>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaPairRDD<String, Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> pairTransformed3 = pairStream1.transformWith(
+ stream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaRDD<String>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaPairRDD<String, Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> pairTransformed4 = pairStream1.transformWith(
+ pairStream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaPairRDD<Double, Character>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaPairRDD<String, Integer> rdd1, JavaPairRDD<Double, Character> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+ }
+
+ @Test
+ public void testStreamingContextTransform(){
+ List<List<Integer>> stream1input = Arrays.asList(
+ Arrays.asList(1),
+ Arrays.asList(2)
+ );
+
+ List<List<Integer>> stream2input = Arrays.asList(
+ Arrays.asList(3),
+ Arrays.asList(4)
+ );
+
+ List<List<Tuple2<Integer, String>>> pairStream1input = Arrays.asList(
+ Arrays.asList(new Tuple2<Integer, String>(1, "x")),
+ Arrays.asList(new Tuple2<Integer, String>(2, "y"))
+ );
+
+ List<List<Tuple2<Integer, Tuple2<Integer, String>>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<Integer, Tuple2<Integer, String>>(1, new Tuple2<Integer, String>(1, "x"))),
+ Arrays.asList(new Tuple2<Integer, Tuple2<Integer, String>>(2, new Tuple2<Integer, String>(2, "y")))
+ );
+
+ JavaDStream<Integer> stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1);
+ JavaDStream<Integer> stream2 = JavaTestUtils.attachTestInputStream(ssc, stream2input, 1);
+ JavaPairDStream<Integer, String> pairStream1 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1));
+
+ List<JavaDStream<?>> listOfDStreams1 = Arrays.<JavaDStream<?>>asList(stream1, stream2);
+
+ // This is just to test whether this transform to a JavaDStream compiles
+ JavaDStream<Long> transformed1 = ssc.transform(
+ listOfDStreams1,
+ new Function2<List<JavaRDD<?>>, Time, JavaRDD<Long>>() {
+ public JavaRDD<Long> call(List<JavaRDD<?>> listOfRDDs, Time time) {
+ assert(listOfRDDs.size() == 2);
+ return null;
+ }
+ }
+ );
+
+ List<JavaDStream<?>> listOfDStreams2 =
+ Arrays.<JavaDStream<?>>asList(stream1, stream2, pairStream1.toJavaDStream());
+
+ JavaPairDStream<Integer, Tuple2<Integer, String>> transformed2 = ssc.transform(
+ listOfDStreams2,
+ new Function2<List<JavaRDD<?>>, Time, JavaPairRDD<Integer, Tuple2<Integer, String>>>() {
+ public JavaPairRDD<Integer, Tuple2<Integer, String>> call(List<JavaRDD<?>> listOfRDDs, Time time) {
+ assert(listOfRDDs.size() == 3);
+ JavaRDD<Integer> rdd1 = (JavaRDD<Integer>)listOfRDDs.get(0);
+ JavaRDD<Integer> rdd2 = (JavaRDD<Integer>)listOfRDDs.get(1);
+ JavaRDD<Tuple2<Integer, String>> rdd3 = (JavaRDD<Tuple2<Integer, String>>)listOfRDDs.get(2);
+ JavaPairRDD<Integer, String> prdd3 = JavaPairRDD.fromJavaRDD(rdd3);
+ PairFunction<Integer, Integer, Integer> mapToTuple = new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i);
+ }
+ };
+ return rdd1.union(rdd2).map(mapToTuple).join(prdd3);
+ }
+ }
+ );
+ JavaTestUtils.attachTestOutputStream(transformed2);
+ List<List<Tuple2<Integer, Tuple2<Integer, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
public void testFlatMap() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("go", "giants"),
@@ -1101,7 +1447,7 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Tuple2<List<String>, List<String>>> grouped = pairStream1.cogroup(pairStream2);
JavaTestUtils.attachTestOutputStream(grouped);
- List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<Tuple2<String, Tuple2<List<String>, List<String>>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -1144,7 +1490,38 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Tuple2<String, String>> joined = pairStream1.join(pairStream2);
JavaTestUtils.attachTestOutputStream(joined);
- List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testLeftOuterJoin() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks") ));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "giants") ),
+ Arrays.asList(new Tuple2<String, String>("new york", "islanders") )
+
+ );
+
+ List<List<Long>> expected = Arrays.asList(Arrays.asList(2L), Arrays.asList(1L));
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<String, Optional<String>>> joined = pairStream1.leftOuterJoin(pairStream2);
+ JavaDStream<Long> counted = joined.count();
+ JavaTestUtils.attachTestOutputStream(counted);
+ List<List<Long>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -1222,14 +1599,20 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
- JavaDStream<String> test1 = ssc.kafkaStream("localhost:12345", "group", topics);
- JavaDStream<String> test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+ JavaPairDStream<String, String> test1 = ssc.kafkaStream("localhost:12345", "group", topics);
+ JavaPairDStream<String, String> test2 = ssc.kafkaStream("localhost:12345", "group", topics,
StorageLevel.MEMORY_AND_DISK());
HashMap<String, String> kafkaParams = Maps.newHashMap();
- kafkaParams.put("zk.connect","localhost:12345");
- kafkaParams.put("groupid","consumer-group");
- JavaDStream<String> test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
+ kafkaParams.put("zookeeper.connect","localhost:12345");
+ kafkaParams.put("group.id","consumer-group");
+ JavaPairDStream<String, String> test3 = ssc.kafkaStream(
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topics,
StorageLevel.MEMORY_AND_DISK());
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index d5cdad4998..42ab9590d6 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -35,9 +35,9 @@ trait JavaTestBase extends TestSuiteBase {
* The stream will be derived from the supplied lists of Java objects.
*/
def attachTestInputStream[T](
- ssc: JavaStreamingContext,
- data: JList[JList[T]],
- numPartitions: Int) = {
+ ssc: JavaStreamingContext,
+ data: JList[JList[T]],
+ numPartitions: Int) = {
val seqData = data.map(Seq(_:_*))
implicit val cm: ClassTag[T] =
@@ -52,12 +52,11 @@ trait JavaTestBase extends TestSuiteBase {
* [[org.apache.spark.streaming.TestOutputStream]].
**/
def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
- dstream: JavaDStreamLike[T, This, R]) =
+ dstream: JavaDStreamLike[T, This, R]) =
{
implicit val cm: ClassTag[T] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
- val ostream = new TestOutputStream(dstream.dstream,
- new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+ val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
dstream.dstream.ssc.registerOutputStream(ostream)
}
@@ -65,9 +64,11 @@ trait JavaTestBase extends TestSuiteBase {
* Process all registered streams for a numBatches batches, failing if
* numExpectedOutput RDD's are not generated. Generated RDD's are collected
* and returned, represented as a list for each batch interval.
+ *
+ * Returns a list of items for each RDD.
*/
def runStreams[V](
- ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+ ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
implicit val cm: ClassTag[V] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
@@ -75,6 +76,27 @@ trait JavaTestBase extends TestSuiteBase {
res.map(entry => out.append(new ArrayList[V](entry)))
out
}
+
+ /**
+ * Process all registered streams for numBatches batches, failing if
+ * numExpectedOutput RDD's are not generated. Generated RDD's are collected
+ * and returned, represented as a list for each batch interval.
+ *
+ * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
+ * representing one partition.
+ */
+ def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
+ numExpectedOutput: Int): JList[JList[JList[V]]] = {
+ implicit val cm: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
+ val out = new ArrayList[JList[JList[V]]]()
+ res.map{entry =>
+ val lists = entry.map(new ArrayList[V](_))
+ out.append(new ArrayList[JList[V]](lists))
+ }
+ out
+ }
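+
+ To make the nesting concrete: runStreamsWithPartitions returns one entry per batch, each containing one entry per partition. A shape-only sketch with made-up values:
+
+     // batches -> partitions -> items (values are illustrative only)
+     val output: Seq[Seq[Seq[Int]]] = Seq(
+       Seq(Seq(1, 2), Seq(3, 4)),   // batch 1: two partitions
+       Seq(Seq(5, 6), Seq(7, 8))    // batch 2: two partitions
+     )
+     output.map(_.flatten)          // collapses back to the flat per-batch form that runStreams returns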
}
object JavaTestUtils extends JavaTestBase {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 11586f72b6..259ef1608c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -18,7 +18,10 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
-import scala.runtime.RichInt
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
import util.ManualClock
class BasicOperationsSuite extends TestSuiteBase {
@@ -82,6 +85,44 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(input, operation, output, true)
}
+ test("repartition (more partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(5)
+ val ssc = setupStreams(input, operation, 2)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 5)
+ assert(second.size === 5)
+ assert(third.size === 5)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
+
+ test("repartition (fewer partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(2)
+ val ssc = setupStreams(input, operation, 5)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 2)
+ assert(second.size === 2)
+ assert(third.size === 2)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
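+
+ These tests exercise DStream.repartition which, assuming it simply repartitions every generated RDD, behaves like the transform-based formulation below; a minimal sketch, not code from this patch:
+
+     // Per-batch repartitioning expressed through transform.
+     val repartitioned: DStream[Int] = stream.transform(rdd => rdd.repartition(5))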
+
test("groupByKey") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
@@ -143,6 +184,72 @@ class BasicOperationsSuite extends TestSuiteBase {
)
}
+ test("union") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 8, 101 to 108, 201 to 208)
+ testOperation(
+ input,
+ (s: DStream[Int]) => s.union(s.map(_ + 4)) ,
+ output
+ )
+ }
+
+ test("StreamingContext.union") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 12, 101 to 112, 201 to 212)
+ // union over 3 DStreams
+ testOperation(
+ input,
+ (s: DStream[Int]) => s.context.union(Seq(s, s.map(_ + 4), s.map(_ + 8))),
+ output
+ )
+ }
+
+ test("transform") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)), // RDD.map in transform
+ input.map(_.map(_.toString))
+ )
+ }
+
+ test("transformWith") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, "x")), ("b", (1, "x")) ),
+ Seq( ("", (1, "x")) ),
+ Seq( ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ val t1 = s1.map(x => (x, 1))
+ val t2 = s2.map(x => (x, "x"))
+ t1.transformWith( // RDD.join in transform
+ t2,
+ (rdd1: RDD[(String, Int)], rdd2: RDD[(String, String)]) => rdd1.join(rdd2)
+ )
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("StreamingContext.transform") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 12, 101 to 112, 201 to 212)
+
+ // transform over 3 DStreams by doing union of the 3 RDDs
+ val operation = (s: DStream[Int]) => {
+ s.context.transform(
+ Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams
+ (rdds: Seq[RDD[_]], time: Time) =>
+ rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs
+ )
+ }
+
+ testOperation(input, operation, output)
+ }
+
test("cogroup") {
val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
@@ -168,7 +275,37 @@ class BasicOperationsSuite extends TestSuiteBase {
Seq( )
)
val operation = (s1: DStream[String], s2: DStream[String]) => {
- s1.map(x => (x,1)).join(s2.map(x => (x,"x")))
+ s1.map(x => (x, 1)).join(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("leftOuterJoin") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, Some("x"))), ("b", (1, Some("x"))) ),
+ Seq( ("", (1, Some("x"))), ("a", (1, None)) ),
+ Seq( ("", (1, None)) ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x, 1)).leftOuterJoin(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("rightOuterJoin") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (Some(1), "x")), ("b", (Some(1), "x")) ),
+ Seq( ("", (Some(1), "x")), ("b", (None, "x")) ),
+ Seq( ),
+ Seq( ("", (None, "x")) )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x, 1)).rightOuterJoin(s2.map(x => (x, "x")))
}
testOperation(inputData1, inputData2, operation, outputData, true)
}
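+
+ For reference, the element types these outer joins produce, as visible in the expected output above: leftOuterJoin yields (K, (V, Option[W])) and rightOuterJoin yields (K, (Option[V], W)); on the Java side the optional value surfaces as com.google.common.base.Optional, as in testLeftOuterJoin in JavaAPISuite above.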
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 07de51bebb..e81287b44e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -372,7 +372,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
logInfo("Manual clock after advancing = " + clock.time)
Thread.sleep(batchDuration.milliseconds)
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
- outputStream.output
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ outputStream.output.map(_.flatten)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 42e3e51e3f..c29b75ece6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -268,8 +268,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
// Test specifying decoder
- val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
- val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
+ val test3 = ssc.kafkaStream[
+ String,
+ String,
+ kafka.serializer.StringDecoder,
+ kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index c91f9ba46d..126915abc9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -61,8 +61,11 @@ class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]],
/**
* This is a output stream just for the testsuites. All the output is collected into a
* ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDD's, each containing a sequence of items
*/
-class TestOutputStream[T: ClassTag](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+class TestOutputStream[T: ClassTag](parent: DStream[T],
+ val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
@@ -77,6 +80,30 @@ class TestOutputStream[T: ClassTag](parent: DStream[T], val output: ArrayBuffer[
}
/**
+ * This is an output stream just for the testsuites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each
+ * containing a sequence of items.
+ */
+class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
+ val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]())
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.glom().collect().map(_.toSeq)
+ output += collected
+ }) {
+
+ // This is to clear the output buffer every time it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+
+ def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+}
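+
+ The per-partition capture above relies on RDD.glom(), which turns each partition into a single array element. A small illustrative example; the exact split depends on how the RDD is partitioned:
+
+     // glom(): RDD[T] becomes RDD[Array[T]], one array per partition.
+     val arrays = sc.parallelize(1 to 6, 2).glom().collect()
+     // typically: Array(Array(1, 2, 3), Array(4, 5, 6))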
+
+/**
* This is the base trait for Spark Streaming testsuites. This provides basic functionality
* to run user-defined set of input on user-defined stream operations, and verify the output.
*/
@@ -109,7 +136,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
*/
def setupStreams[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V]
+ operation: DStream[U] => DStream[V],
+ numPartitions: Int = numInputPartitions
): StreamingContext = {
// Create StreamingContext
@@ -119,9 +147,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
}
// Setup the stream computation
- val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+ val inputStream = new TestInputStream(ssc, input, numPartitions)
val operatedStream = operation(inputStream)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
ssc.registerInputStream(inputStream)
ssc.registerOutputStream(outputStream)
ssc
@@ -147,7 +176,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
val operatedStream = operation(inputStream1, inputStream2)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]])
ssc.registerInputStream(inputStream1)
ssc.registerInputStream(inputStream2)
ssc.registerOutputStream(outputStream)
@@ -158,18 +188,37 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
* returns the collected output. It will wait until `numExpectedOutput` number of
* output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of items for each RDD.
*/
def runStreams[V: ClassTag](
ssc: StreamingContext,
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
+ // Flatten each RDD into a single Seq
+ runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+ }
+
+ /**
+ * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+ * returns the collected output. It will wait until `numExpectedOutput` number of
+ * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
+ * representing one partition.
+ */
+ def runStreamsWithPartitions[V: ClassTag](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[Seq[V]]] = {
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
// Get the output buffer
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
val output = outputStream.output
try {
diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
index f824c472ae..f670f65bf5 100644
--- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
@@ -199,7 +199,7 @@ object JavaAPICompletenessChecker {
private def toJavaMethod(method: SparkMethod): SparkMethod = {
val params = method.parameters
- .filterNot(_.name == "scala.reflect.ClassManifest")
+ .filterNot(_.name == "scala.reflect.ClassTag")
.map(toJavaType(_, isReturnType = false))
SparkMethod(method.name, toJavaType(method.returnType, isReturnType = true), params)
}
@@ -212,7 +212,7 @@ object JavaAPICompletenessChecker {
// internal Spark components.
val excludedNames = Seq(
"org.apache.spark.rdd.RDD.origin",
- "org.apache.spark.rdd.RDD.elementClassManifest",
+ "org.apache.spark.rdd.RDD.elementClassTag",
"org.apache.spark.rdd.RDD.checkpointData",
"org.apache.spark.rdd.RDD.partitioner",
"org.apache.spark.rdd.RDD.partitions",
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 858b58d338..c1a87d3373 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -17,22 +17,25 @@
package org.apache.spark.deploy.yarn
+import java.io.IOException;
import java.net.Socket
+import java.security.PrivilegedExceptionAction
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
-import scala.collection.JavaConversions._
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.util.Utils
import org.apache.hadoop.security.UserGroupInformation
-import java.security.PrivilegedExceptionAction
+import scala.collection.JavaConversions._
class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
@@ -43,18 +46,26 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
private var appAttemptId: ApplicationAttemptId = null
private var userThread: Thread = null
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+ private val fs = FileSystem.get(yarnConf)
private var yarnAllocator: YarnAllocationHandler = null
private var isFinished:Boolean = false
private var uiAddress: String = ""
+ private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES,
+ YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES)
+ private var isLastAMRetry: Boolean = true
def run() {
// setup the directories so things go to yarn approved directories rather
// then user specified and /tmp
System.setProperty("spark.local.dir", getLocalDirs())
+
+ // Use priority 30 as it's higher than HDFS. It's the same priority that MapReduce uses.
+ ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
appAttemptId = getApplicationAttemptId()
+ isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts;
resourceManager = registerWithResourceManager()
// Workaround until hadoop moves to something which has
@@ -183,6 +194,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
// It needs a shutdown hook to set SUCCEEDED
successed = true
} finally {
+ logDebug("finishing main")
+ isLastAMRetry = true;
if (successed) {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
} else {
@@ -229,8 +242,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
}
}
-
-
private def allocateWorkers() {
try {
logInfo("Allocating " + args.numWorkers + " workers.")
@@ -329,6 +340,40 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
resourceManager.finishApplicationMaster(finishReq)
}
+
+ /**
+ * Clean up the staging directory.
+ */
+ private def cleanupStagingDir() {
+ var stagingDirPath: Path = null
+ try {
+ val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean
+ if (!preserveFiles) {
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent()
+ if (stagingDirPath == null) {
+ logError("Staging directory is null")
+ return
+ }
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case e: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, e)
+ }
+ }
+
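+ cleanupStagingDir recovers the staging directory as the parent of the uploaded Spark jar, whose location the Client exports through SPARK_YARN_JAR_PATH. A sketch with a purely hypothetical path:
+
+     import org.apache.hadoop.fs.Path
+
+     // Hypothetical value of SPARK_YARN_JAR_PATH set by the Client:
+     val jarPath = new Path("hdfs://nn:8020/user/alice/.sparkStaging/application_1384_0001/spark.jar")
+     jarPath.getParent()   // hdfs://nn:8020/user/alice/.sparkStaging/application_1384_0001
+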
+ // The shutdown hook that runs when a signal is received AND during normal
+ // close of the JVM.
+ class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable {
+
+ def run() {
+ logInfo("AppMaster received a signal.")
+ // We need to clean up the staging dir before HDFS is shut down, but make sure
+ // we don't delete it until this is the last AM
+ if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
+ }
+ }
}
@@ -368,6 +413,8 @@ object ApplicationMaster {
// Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
// Should not really have to do this, but it helps yarn to evict resources earlier.
// not to mention, prevent Client declaring failure even though we exit'ed properly.
+ // Note that this will unfortunately not properly clean up the staging files because it gets called too
+ // late and the filesystem is already shut down.
if (modified) {
Runtime.getRuntime().addShutdownHook(new Thread with Logging {
// This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 076dd3c9b0..1a380ae714 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -45,7 +45,13 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
- val credentials = UserGroupInformation.getCurrentUser().getCredentials();
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ private var distFiles = None: Option[String]
+ private var distFilesTimeStamps = None: Option[String]
+ private var distFilesFileSizes = None: Option[String]
+ private var distArchives = None: Option[String]
+ private var distArchivesTimeStamps = None: Option[String]
+ private var distArchivesFileSizes = None: Option[String]
def run() {
init(yarnConf)
@@ -57,7 +63,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
- val localResources = prepareLocalResources(appId, "spark")
+ val localResources = prepareLocalResources(appId, ".sparkStaging")
val env = setupLaunchEnv(localResources)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
@@ -109,10 +115,73 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
appContext.setApplicationName(args.appName)
return appContext
}
-
- def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
+
+ /**
+ * Copy the local file into HDFS and configure it to be distributed with the
+ * job via the distributed cache.
+ * If a fragment is specified, the file will be referenced as that fragment.
+ */
+ private def copyLocalFile(
+ dstDir: Path,
+ resourceType: LocalResourceType,
+ originalPath: Path,
+ replication: Short,
+ localResources: HashMap[String,LocalResource],
+ fragment: String,
+ appMasterOnly: Boolean = false): Unit = {
+ val fs = FileSystem.get(conf)
+ val newPath = new Path(dstDir, originalPath.getName())
+ logInfo("Uploading " + originalPath + " to " + newPath)
+ fs.copyFromLocalFile(false, true, originalPath, newPath)
+ fs.setReplication(newPath, replication);
+ val destStatus = fs.getFileStatus(newPath)
+
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName());
+ if ((fragment == null) || (fragment.isEmpty())){
+ localResources(originalPath.getName()) = amJarRsrc
+ } else {
+ localResources(fragment) = amJarRsrc
+ pathURI = new URI(newPath.toString() + "#" + fragment);
+ }
+ val distPath = pathURI.toString()
+ if (appMasterOnly == true) return
+ if (resourceType == LocalResourceType.FILE) {
+ distFiles match {
+ case Some(path) =>
+ distFilesFileSizes = Some(distFilesFileSizes.get + "," +
+ destStatus.getLen().toString())
+ distFilesTimeStamps = Some(distFilesTimeStamps.get + "," +
+ destStatus.getModificationTime().toString())
+ distFiles = Some(path + "," + distPath)
+ case _ =>
+ distFilesFileSizes = Some(destStatus.getLen().toString())
+ distFilesTimeStamps = Some(destStatus.getModificationTime().toString())
+ distFiles = Some(distPath)
+ }
+ } else {
+ distArchives match {
+ case Some(path) =>
+ distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," +
+ destStatus.getModificationTime().toString())
+ distArchivesFileSizes = Some(distArchivesFileSizes.get + "," +
+ destStatus.getLen().toString())
+ distArchives = Some(path + "," + distPath)
+ case _ =>
+ distArchivesTimeStamps = Some(destStatus.getModificationTime().toString())
+ distArchivesFileSizes = Some(destStatus.getLen().toString())
+ distArchives = Some(distPath)
+ }
+ }
+ }
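+
+ Roughly how the fragment handling above plays out: when a file is passed as /local/lookup.txt#lookup, it is uploaded under its original name, registered in localResources under the fragment, and the fragment is carried along in the recorded cache path. The values below are hypothetical and for illustration only:
+
+     --files /local/lookup.txt#lookup
+       localResources key : "lookup"
+       distFiles entry    : hdfs://nn:8020/user/alice/.sparkStaging/<appId>/lookup.txt#lookup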
+
+ def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- val locaResources = HashMap[String, LocalResource]()
// Upload Spark and the application JAR to the remote file system
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
@@ -125,33 +194,69 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
}
+ val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/"
+ val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort
+
+ if (UserGroupInformation.isSecurityEnabled()) {
+ val dstFs = dst.getFileSystem(conf)
+ dstFs.addDelegationTokens(delegTokenRenewer, credentials);
+ }
+ val localResources = HashMap[String, LocalResource]()
+
Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
val src = new Path(localPath)
- val pathSuffix = appName + "/" + appId.getId() + destName
- val dst = new Path(fs.getHomeDirectory(), pathSuffix)
- logInfo("Uploading " + src + " to " + dst)
- fs.copyFromLocalFile(false, true, src, dst)
- val destStatus = fs.getFileStatus(dst)
-
- // get tokens for anything we upload to hdfs
- if (UserGroupInformation.isSecurityEnabled()) {
- fs.addDelegationTokens(delegTokenRenewer, credentials);
- }
+ val newPath = new Path(dst, destName)
+ logInfo("Uploading " + src + " to " + newPath)
+ fs.copyFromLocalFile(false, true, src, newPath)
+ fs.setReplication(newPath, replication);
+ val destStatus = fs.getFileStatus(newPath)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(LocalResourceType.FILE)
amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
amJarRsrc.setTimestamp(destStatus.getModificationTime())
amJarRsrc.setSize(destStatus.getLen())
- locaResources(destName) = amJarRsrc
+ localResources(destName) = amJarRsrc
+ }
+ }
+
+ // handle any add jars
+ if ((args.addJars != null) && (!args.addJars.isEmpty())){
+ args.addJars.split(',').foreach { case file: String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
+ tmpURI.getFragment(), true)
+ }
+ }
+
+ // handle any distributed cache files
+ if ((args.files != null) && (!args.files.isEmpty())){
+ args.files.split(',').foreach { case file: String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
+ tmpURI.getFragment())
+ }
+ }
+
+ // handle any distributed cache archives
+ if ((args.archives != null) && (!args.archives.isEmpty())) {
+ args.archives.split(',').foreach { case file:String =>
+ val tmpURI = new URI(file)
+ val tmp = new Path(tmpURI)
+ copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication,
+ localResources, tmpURI.getFragment())
}
}
+
UserGroupInformation.getCurrentUser().addCredentials(credentials);
- return locaResources
+ return localResources
}
def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
@@ -160,12 +265,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val env = new HashMap[String, String]()
- // If log4j present, ensure ours overrides all others
- if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
-
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_JAR_PATH") =
localResources("spark.jar").getResource().getScheme.toString() + "://" +
@@ -186,6 +286,18 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
}
+ // set the environment variables to be passed on to the Workers
+ if (distFiles != None) {
+ env("SPARK_YARN_CACHE_FILES") = distFiles.get
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get
+ }
+ if (distArchives != None) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get
+ }
+
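+ These cache variables are index-aligned, comma-separated lists that WorkerRunnable later splits on ',' and reads back position by position. A hypothetical example of the resulting environment:
+
+     SPARK_YARN_CACHE_FILES=hdfs://nn:8020/.../lookup.txt#lookup,hdfs://nn:8020/.../dict.txt#dict
+     SPARK_YARN_CACHE_FILES_TIME_STAMPS=1384300000000,1384300001000
+     SPARK_YARN_CACHE_FILES_FILE_SIZES=2048,4096
+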
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
@@ -335,4 +447,30 @@ object Client {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
+
+ def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ // If log4j present, ensure ours overrides all others
+ if (addLog4j) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "log4j.properties")
+ }
+ // Normally the user's app.jar goes last, in case it conflicts with the Spark jars.
+ val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
+ .toBoolean
+ if (userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "app.jar")
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "spark.jar")
+ Client.populateHadoopClasspath(conf, env)
+
+ if (!userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "app.jar")
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "*")
+ }
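+
+ Reading the additions in order, the container CLASSPATH produced by populateClasspath comes out as follows (the log4j.properties entry only when one is present):
+
+     default (spark.yarn.user.classpath.first=false):
+       $PWD : [$PWD/log4j.properties] : $PWD/spark.jar : <hadoop classpath> : $PWD/app.jar : $PWD/*
+     spark.yarn.user.classpath.first=true:
+       $PWD : [$PWD/log4j.properties] : $PWD/app.jar : $PWD/spark.jar : <hadoop classpath> : $PWD/*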
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index c56dbd99ba..852dbd7dab 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -24,6 +24,9 @@ import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo}
// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
class ClientArguments(val args: Array[String]) {
+ var addJars: String = null
+ var files: String = null
+ var archives: String = null
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
@@ -81,6 +84,17 @@ class ClientArguments(val args: Array[String]) {
case ("--name") :: value :: tail =>
appName = value
+
+ case ("--addJars") :: value :: tail =>
+ addJars = value
+ args = tail
+
+ case ("--files") :: value :: tail =>
+ files = value
+ args = tail
+
+ case ("--archives") :: value :: tail =>
+ archives = value
args = tail
case Nil =>
@@ -97,7 +111,7 @@ class ClientArguments(val args: Array[String]) {
inputFormatInfo = inputFormatMap.values.toList
}
-
+
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
@@ -113,10 +127,13 @@ class ClientArguments(val args: Array[String]) {
" --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
" --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
- " --name NAME The name of your application (Default: Spark)\n" +
- " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')"
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
)
System.exit(exitCode)
}
-
+
}
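+
+ A hedged usage sketch of the new client options; the remaining arguments the YARN client requires (for example the user jar and main class) are omitted here:
+
+     org.apache.spark.deploy.yarn.Client \
+       --name MyStreamingApp \
+       --addJars /local/extra-lib.jar \
+       --files /local/lookup.txt#lookup \
+       --archives /local/deps.zip#deps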
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index a60e8a3007..ba352daac4 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -121,7 +121,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
- " org.apache.spark.executor.StandaloneExecutorBackend " +
+ " org.apache.spark.executor.CoarseGrainedExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
@@ -137,11 +137,26 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
startReq.setContainerLaunchContext(ctx)
cm.startContainer(startReq)
}
+
+ private def setupDistributedCache(file: String,
+ rtype: LocalResourceType,
+ localResources: HashMap[String, LocalResource],
+ timestamp: String,
+ size: String) = {
+ val uri = new URI(file)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(rtype)
+ amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+ amJarRsrc.setTimestamp(timestamp.toLong)
+ amJarRsrc.setSize(size.toLong)
+ localResources(uri.getFragment()) = amJarRsrc
+ }
def prepareLocalResources: HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- val locaResources = HashMap[String, LocalResource]()
+ val localResources = HashMap[String, LocalResource]()
// Spark JAR
val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
@@ -151,7 +166,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
- locaResources("spark.jar") = sparkJarResource
+ localResources("spark.jar") = sparkJarResource
// User JAR
val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
userJarResource.setType(LocalResourceType.FILE)
@@ -160,7 +175,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
- locaResources("app.jar") = userJarResource
+ localResources("app.jar") = userJarResource
// Log4j conf - if available
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
@@ -171,27 +186,37 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
- locaResources("log4j.properties") = log4jConfResource
+ localResources("log4j.properties") = log4jConfResource
+ }
+
+ if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ for( i <- 0 to distFiles.length - 1) {
+ setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
+ fileSizes(i))
+ }
}
+ if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
+ val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ for( i <- 0 to distArchives.length - 1) {
+ setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
+ timeStamps(i), fileSizes(i))
+ }
+ }
- logInfo("Prepared Local resources " + locaResources)
- return locaResources
+ logInfo("Prepared Local resources " + localResources)
+ return localResources
}
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
- // If log4j present, ensure ours overrides all others
- if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
- // Which is correct ?
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
- }
-
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index d222f412a0..4beb5229fe 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -22,7 +22,7 @@ import org.apache.spark.util.Utils
import org.apache.spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
@@ -211,7 +211,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val workerId = workerIdCounter.incrementAndGet().toString
val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..