Diffstat (limited to 'core/src/main/scala/org/apache/spark')
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulators.scala | 256
-rw-r--r--  core/src/main/scala/org/apache/spark/Aggregator.scala | 61
-rw-r--r--  core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala | 89
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala | 82
-rw-r--r--  core/src/main/scala/org/apache/spark/ClosureCleaner.scala | 231
-rw-r--r--  core/src/main/scala/org/apache/spark/Dependency.scala | 81
-rw-r--r--  core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala | 78
-rw-r--r--  core/src/main/scala/org/apache/spark/FetchFailedException.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/HttpFileServer.scala | 62
-rw-r--r--  core/src/main/scala/org/apache/spark/HttpServer.scala | 88
-rw-r--r--  core/src/main/scala/org/apache/spark/JavaSerializer.scala | 83
-rw-r--r--  core/src/main/scala/org/apache/spark/KryoSerializer.scala | 156
-rw-r--r--  core/src/main/scala/org/apache/spark/Logging.scala | 95
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 338
-rw-r--r--  core/src/main/scala/org/apache/spark/PairRDDFunctions.scala | 703
-rw-r--r--  core/src/main/scala/org/apache/spark/Partition.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/Partitioner.scala | 135
-rw-r--r--  core/src/main/scala/org/apache/spark/RDD.scala | 957
-rw-r--r--  core/src/main/scala/org/apache/spark/RDDCheckpointData.scala | 130
-rw-r--r--  core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/SerializableWritable.scala | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/ShuffleFetcher.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/SizeEstimator.scala | 283
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 995
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 240
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkException.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkFiles.java | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala | 201
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 41
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskState.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/Utils.scala | 780
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 167
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 601
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala | 114
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 426
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala | 418
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java | 64
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/StorageLevels.java | 48
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function.java | 39
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/Function2.java | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala | 50
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 344
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala | 1057
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala | 70
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala | 30
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 171
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala | 409
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala | 54
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala | 602
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/Command.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala | 130
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala | 86
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala | 69
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/WebUI.scala | 47
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/Client.scala | 145
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala | 85
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/Master.scala | 386
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala | 89
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala | 77
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala | 141
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala | 80
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala | 199
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala | 213
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala | 153
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala | 115
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala | 190
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 269
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala | 95
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala | 105
-rw-r--r--  core/src/main/scala/org/apache/spark/io/CompressionCodec.scala | 82
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala | 100
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala | 163
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala | 59
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala | 68
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/metrics/source/Source.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/network/BufferMessage.scala | 111
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Connection.scala | 586
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala | 720
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala | 102
-rw-r--r--  core/src/main/scala/org/apache/spark/network/Message.scala | 93
-rw-r--r--  core/src/main/scala/org/apache/spark/network/MessageChunk.scala | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala | 75
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ReceiverTest.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/network/SenderTest.scala | 70
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala | 74
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala | 70
-rw-r--r--  core/src/main/scala/org/apache/spark/package.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala | 87
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala | 79
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala | 82
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala | 89
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala | 58
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/PartialResult.scala | 137
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala | 43
-rw-r--r--  core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala | 68
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala | 90
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala | 155
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 144
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala | 342
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala | 33
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 137
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala | 120
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala | 41
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala | 30
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 126
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala | 151
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 72
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala | 125
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala | 67
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala | 129
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 73
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala | 143
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala | 85
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala | 39
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 849
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala | 30
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala | 178
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobListener.scala | 28
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala | 292
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobResult.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 134
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 189
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 204
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala | 74
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala | 78
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 112
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 115
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala | 72
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 52
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala | 35
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala | 440
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala | 712
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala | 121
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala | 48
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala | 137
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala | 81
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala | 91
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala | 198
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala | 37
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala | 72
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala | 272
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala | 194
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala | 286
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala | 342
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala | 112
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 62
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockException.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala | 27
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala | 348
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 1046
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala | 118
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 178
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 404
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala | 110
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala | 39
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala | 48
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala | 139
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessage.scala | 223
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala | 159
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 65
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockStore.scala | 61
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskStore.scala | 329
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/MemoryStore.scala | 257
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/PutResult.scala | 26
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 67
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageLevel.scala | 146
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/StorageUtils.scala | 115
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala | 113
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/Page.scala | 22
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 87
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/UIUtils.scala | 131
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala | 105
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala | 91
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 136
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala | 90
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 156
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala | 32
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala | 55
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 183
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala | 41
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala | 65
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala | 132
-rw-r--r--  core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 72
-rw-r--r--  core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala | 62
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala | 80
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Clock.scala | 29
-rw-r--r--  core/src/main/scala/org/apache/spark/util/CompletionIterator.scala | 42
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Distribution.scala | 82
-rw-r--r--  core/src/main/scala/org/apache/spark/util/IdGenerator.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/util/IntParam.scala | 31
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MemoryParam.scala | 34
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 61
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MutablePair.scala | 36
-rw-r--r--  core/src/main/scala/org/apache/spark/util/NextIterator.scala | 88
-rw-r--r--  core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala | 79
-rw-r--r--  core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala | 54
-rw-r--r--  core/src/main/scala/org/apache/spark/util/StatCounter.scala | 131
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala | 122
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala | 86
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Vector.scala | 139
264 files changed, 34428 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
new file mode 100644
index 0000000000..5177ee58fa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+
+import scala.collection.mutable.Map
+import scala.collection.generic.Growable
+
+/**
+ * A datatype that can be accumulated, i.e. has a commutative and associative "add" operation,
+ * but where the result type, `R`, may be different from the element type being added, `T`.
+ *
+ * You must define how to add data, and how to merge two of these together. For some datatypes,
+ * such as a counter, these might be the same operation. In that case, you can use the simpler
+ * [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
+ * accumulating a set. You will add items to the set, and you will union two sets together.
+ *
+ * @param initialValue initial value of accumulator
+ * @param param helper object defining how to add elements of type `R` and `T`
+ * @tparam R the full accumulated data (result type)
+ * @tparam T partial data that can be added in
+ */
+class Accumulable[R, T] (
+ @transient initialValue: R,
+ param: AccumulableParam[R, T])
+ extends Serializable {
+
+ val id = Accumulators.newId
+ @transient private var value_ = initialValue // Current value on master
+ val zero = param.zero(initialValue) // Zero value to be passed to workers
+ var deserialized = false
+
+ Accumulators.register(this, true)
+
+ /**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def += (term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `+=`.
+ * @param term the other `R` that will get merged with this
+ */
+ def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `add`.
+ * @param term the other `R` that will get merged with this
+ */
+ def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
+ /**
+ * Access the accumulator's current value; only allowed on master.
+ */
+ def value: R = {
+ if (!deserialized) {
+ value_
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
+ }
+
+ /**
+ * Get the current value of this accumulator from within a task.
+ *
+ * This is NOT the global value of the accumulator. To get the global value after a
+ * completed operation on the dataset, call `value`.
+ *
+ * The typical use of this method is to directly mutate the local value, e.g., to add
+ * an element to a Set.
+ */
+ def localValue = value_
+
+ /**
+ * Set the accumulator's value; only allowed on master.
+ */
+ def value_= (newValue: R) {
+ if (!deserialized) value_ = newValue
+ else throw new UnsupportedOperationException("Can't assign accumulator value in task")
+ }
+
+ /**
+ * Set the accumulator's value; only allowed on master
+ */
+ def setValue(newValue: R) {
+ this.value = newValue
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ value_ = zero
+ deserialized = true
+ Accumulators.register(this, false)
+ }
+
+ override def toString = value_.toString
+}
+
+/**
+ * Helper object defining how to accumulate values of a particular type. An implicit
+ * AccumulableParam needs to be available when you create Accumulables of a specific type.
+ *
+ * @tparam R the full accumulated data (result type)
+ * @tparam T partial data that can be added in
+ */
+trait AccumulableParam[R, T] extends Serializable {
+ /**
+ * Add additional data to the accumulator value. Is allowed to modify and return `r`
+ * for efficiency (to avoid allocating objects).
+ *
+ * @param r the current value of the accumulator
+ * @param t the data to be added to the accumulator
+ * @return the new value of the accumulator
+ */
+ def addAccumulator(r: R, t: T): R
+
+ /**
+ * Merge two accumulated values together. Is allowed to modify and return the first value
+ * for efficiency (to avoid allocating objects).
+ *
+ * @param r1 one set of accumulated data
+ * @param r2 another set of accumulated data
+ * @return both data sets merged together
+ */
+ def addInPlace(r1: R, r2: R): R
+
+ /**
+ * Return the "zero" (identity) value for an accumulator type, given its initial value. For
+ * example, if R was a vector of N dimensions, this would return a vector of N zeroes.
+ */
+ def zero(initialValue: R): R
+}
+
+private[spark]
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ extends AccumulableParam[R,T] {
+
+ def addAccumulator(growable: R, elem: T): R = {
+ growable += elem
+ growable
+ }
+
+ def addInPlace(t1: R, t2: R): R = {
+ t1 ++= t2
+ t1
+ }
+
+ def zero(initialValue: R): R = {
+ // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+ // Instead we'll serialize it to a buffer and load it back.
+ val ser = new JavaSerializer().newInstance()
+ val copy = ser.deserialize[R](ser.serialize(initialValue))
+ copy.clear() // In case it contained stuff
+ copy
+ }
+}
+
+/**
+ * A simpler version of [[org.apache.spark.Accumulable]] where the result type being accumulated is the same
+ * as the type of the elements being merged.
+ *
+ * @param initialValue initial value of accumulator
+ * @param param helper object defining how to add elements of type `T`
+ * @tparam T result type
+ */
+class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
+ extends Accumulable[T,T](initialValue, param)
+
+/**
+ * A simpler version of [[org.apache.spark.AccumulableParam]] where the only datatype you can add in is the same type
+ * as the accumulated value. An implicit AccumulatorParam object needs to be available when you create
+ * Accumulators of a specific type.
+ *
+ * @tparam T type of value to accumulate
+ */
+trait AccumulatorParam[T] extends AccumulableParam[T, T] {
+ def addAccumulator(t1: T, t2: T): T = {
+ addInPlace(t1, t2)
+ }
+}
+
+// TODO: The multi-thread support in accumulators is kind of lame; check
+// if there's a more intuitive way of doing it right
+private object Accumulators {
+ // TODO: Use soft references? => need to make readObject work properly then
+ val originals = Map[Long, Accumulable[_, _]]()
+ val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+ var lastId: Long = 0
+
+ def newId: Long = synchronized {
+ lastId += 1
+ return lastId
+ }
+
+ def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
+ if (original) {
+ originals(a.id) = a
+ } else {
+ val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
+ accums(a.id) = a
+ }
+ }
+
+ // Clear the local (non-original) accumulators for the current thread
+ def clear() {
+ synchronized {
+ localAccums.remove(Thread.currentThread)
+ }
+ }
+
+ // Get the values of the local accumulators for the current thread (by ID)
+ def values: Map[Long, Any] = synchronized {
+ val ret = Map[Long, Any]()
+ for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+ ret(id) = accum.localValue
+ }
+ return ret
+ }
+
+ // Add values to the original accumulators with some given IDs
+ def add(values: Map[Long, Any]): Unit = synchronized {
+ for ((id, value) <- values) {
+ if (originals.contains(id)) {
+ originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
+ }
+ }
+ }
+}
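For orientation, here is a minimal usage sketch of the accumulator API introduced above. It assumes a live SparkContext named `sc` and relies on `import org.apache.spark.SparkContext._` to supply the implicit AccumulatorParam[Int]; neither is part of this file.

    import org.apache.spark.SparkContext._  // implicit AccumulatorParam[Int]

    val errorCount = sc.accumulator(0)      // Accumulator[Int], registered on the master

    sc.parallelize(1 to 100).foreach { n =>
      // Tasks may only add; reading errorCount.value here would throw
      // UnsupportedOperationException because `deserialized` is true on workers.
      if (n % 10 == 0) errorCount += 1
    }

    println(errorCount.value)               // 10, readable only on the master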
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
new file mode 100644
index 0000000000..3ef402926e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConversions._
+
+/** A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue functions.
+ */
+case class Aggregator[K, V, C] (
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C) {
+
+ def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
+ val combiners = new JHashMap[K, C]
+ for (kv <- iter) {
+ val oldC = combiners.get(kv._1)
+ if (oldC == null) {
+ combiners.put(kv._1, createCombiner(kv._2))
+ } else {
+ combiners.put(kv._1, mergeValue(oldC, kv._2))
+ }
+ }
+ combiners.iterator
+ }
+
+ def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
+ val combiners = new JHashMap[K, C]
+ iter.foreach { case(k, c) =>
+ val oldC = combiners.get(k)
+ if (oldC == null) {
+ combiners.put(k, c)
+ } else {
+ combiners.put(k, mergeCombiners(oldC, c))
+ }
+ }
+ combiners.iterator
+ }
+}
+
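A small, self-contained sketch of the three functions above, exercising combineValuesByKey directly rather than through a shuffle (the keys and values are made up):

    val agg = Aggregator[String, Int, Int](
      createCombiner = v => v,                // first value seen for a key starts its combiner
      mergeValue = (c, v) => c + v,           // fold further values into the combiner
      mergeCombiners = (c1, c2) => c1 + c2)   // merge combiners from different map tasks

    val pairs = Iterator(("a", 1), ("b", 1), ("a", 1))
    val combined = agg.combineValuesByKey(pairs).toMap
    // combined == Map("a" -> 2, "b" -> 1)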
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
new file mode 100644
index 0000000000..908ff56a6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.CompletionIterator
+
+
+private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
+
+ override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ : Iterator[T] =
+ {
+
+ logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
+ val blockManager = SparkEnv.get.blockManager
+
+ val startTime = System.currentTimeMillis
+ val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+ logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
+ shuffleId, reduceId, System.currentTimeMillis - startTime))
+
+ val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
+ for (((address, size), index) <- statuses.zipWithIndex) {
+ splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
+ }
+
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+ case (address, splits) =>
+ (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+ }
+
+ def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+ val blockId = blockPair._1
+ val blockOption = blockPair._2
+ blockOption match {
+ case Some(block) => {
+ block.asInstanceOf[Iterator[T]]
+ }
+ case None => {
+ val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
+ blockId match {
+ case regex(shufId, mapId, _) =>
+ val address = statuses(mapId.toInt)._1
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
+ case _ =>
+ throw new SparkException(
+ "Failed to get block " + blockId + ", which is not a shuffle block")
+ }
+ }
+ }
+ }
+
+ val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+ val itr = blockFetcherItr.flatMap(unpackBlock)
+
+ CompletionIterator[T, Iterator[T]](itr, {
+ val shuffleMetrics = new ShuffleReadMetrics
+ shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
+ shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
+ shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
+ shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
+ shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
+ shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
+ shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
+ metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ })
+ }
+}
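The failure handling above relies on the shuffle block naming convention "shuffle_<shuffleId>_<mapId>_<reduceId>". A standalone sketch of that round trip, with made-up ids:

    val blockId = "shuffle_%d_%d_%d".format(2, 7, 3)       // shuffleId=2, mapId=7, reduceId=3

    val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
    blockId match {
      case regex(shufId, mapId, redId) =>
        // On a missing block, these captures identify which map output failed to fetch.
        println("shuffle=" + shufId + " map=" + mapId + " reduce=" + redId)
      case _ =>
        println(blockId + " is not a shuffle block")
    }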
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
new file mode 100644
index 0000000000..42e465b9d8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
+import org.apache.spark.storage.{BlockManager, StorageLevel}
+
+
+/** Spark class responsible for passing an RDD's split contents to the BlockManager and making
+ sure a node doesn't load two copies of an RDD split at once.
+ */
+private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
+ private val loading = new HashSet[String]
+
+ /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
+ def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
+ : Iterator[T] = {
+ val key = "rdd_%d_%d".format(rdd.id, split.index)
+ logInfo("Cache key is " + key)
+ blockManager.get(key) match {
+ case Some(cachedValues) =>
+ // Partition is in cache, so just return its values
+ logInfo("Found partition in cache!")
+ return cachedValues.asInstanceOf[Iterator[T]]
+
+ case None =>
+ // Mark the split as loading (unless someone else marks it first)
+ loading.synchronized {
+ if (loading.contains(key)) {
+ logInfo("Loading contains " + key + ", waiting...")
+ while (loading.contains(key)) {
+ try {loading.wait()} catch {case _ : Throwable =>}
+ }
+ logInfo("Loading no longer contains " + key + ", so returning cached result")
+ // See whether someone else has successfully loaded it. The main way this could fail
+ // is if the cache did not keep the partition that the other thread loaded because it
+ // chose not to make space for it. However, that case is unlikely because it's
+ // unlikely that two threads would work on the same RDD partition. One downside of
+ // the current code is that threads wait serially if this does happen.
+ blockManager.get(key) match {
+ case Some(values) =>
+ return values.asInstanceOf[Iterator[T]]
+ case None =>
+ logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
+ loading.add(key)
+ }
+ } else {
+ loading.add(key)
+ }
+ }
+ try {
+ // If we got here, we have to load the split
+ val elements = new ArrayBuffer[Any]
+ logInfo("Computing partition " + split)
+ elements ++= rdd.computeOrReadCheckpoint(split, context)
+ // Try to put this block in the blockManager
+ blockManager.put(key, elements, storageLevel, true)
+ return elements.iterator.asInstanceOf[Iterator[T]]
+ } finally {
+ loading.synchronized {
+ loading.remove(key)
+ loading.notifyAll()
+ }
+ }
+ }
+ }
+}
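The `loading` set above implements a "first caller computes, later callers wait" guard keyed by block id. Below is a minimal sketch of the same wait/notify pattern outside Spark, with a plain in-memory map standing in for the BlockManager; it is illustrative only, not part of this diff.

    import scala.collection.mutable.{HashMap, HashSet}

    object OncePerKey {
      private val cache = new HashMap[String, String]
      private val loading = new HashSet[String]

      def getOrCompute(key: String)(compute: => String): String = {
        val cached = loading.synchronized {
          while (loading.contains(key)) loading.wait()   // someone else is computing this key
          val hit = cache.get(key)                       // it may have finished while we waited
          if (hit.isEmpty) loading.add(key)              // otherwise claim the key ourselves
          hit
        }
        cached.getOrElse {
          try {
            val v = compute                              // expensive work, done outside the lock
            loading.synchronized { cache(key) = v }
            v
          } finally {
            loading.synchronized {
              loading.remove(key)
              loading.notifyAll()                        // wake any threads waiting on this key
            }
          }
        }
      }
    }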
diff --git a/core/src/main/scala/org/apache/spark/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/ClosureCleaner.scala
new file mode 100644
index 0000000000..71d9e62d4f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ClosureCleaner.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.reflect.Field
+
+import scala.collection.mutable.Map
+import scala.collection.mutable.Set
+
+import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
+import org.objectweb.asm.Opcodes._
+import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
+
+private[spark] object ClosureCleaner extends Logging {
+ // Get an ASM class reader for a given class from the JAR that loaded it
+ private def getClassReader(cls: Class[_]): ClassReader = {
+ // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ // todo: Fixme - continuing with earlier behavior ...
+ if (resourceStream == null) return new ClassReader(resourceStream)
+
+ val baos = new ByteArrayOutputStream(128)
+ Utils.copyStream(resourceStream, baos, true)
+ new ClassReader(new ByteArrayInputStream(baos.toByteArray))
+ }
+
+ // Check whether a class represents a Scala closure
+ private def isClosure(cls: Class[_]): Boolean = {
+ cls.getName.contains("$anonfun$")
+ }
+
+ // Get a list of the classes of the outer objects of a given closure object, obj;
+ // the outer objects are defined as any closures that obj is nested within, plus
+ // possibly the class that the outermost closure is in, if any. We stop searching
+ // for outer objects beyond that because cloning the user's object is probably
+ // not a good idea (whereas we can clone closure objects just fine since we
+ // understand how all their fields are used).
+ private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ if (isClosure(f.getType)) {
+ return f.getType :: getOuterClasses(f.get(obj))
+ } else {
+ return f.getType :: Nil // Stop at the first $outer that is not a closure
+ }
+ }
+ return Nil
+ }
+
+ // Get a list of the outer objects for a given closure object.
+ private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
+ for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
+ f.setAccessible(true)
+ if (isClosure(f.getType)) {
+ return f.get(obj) :: getOuterObjects(f.get(obj))
+ } else {
+ return f.get(obj) :: Nil // Stop at the first $outer that is not a closure
+ }
+ }
+ return Nil
+ }
+
+ private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
+ val seen = Set[Class[_]](obj.getClass)
+ var stack = List[Class[_]](obj.getClass)
+ while (!stack.isEmpty) {
+ val cr = getClassReader(stack.head)
+ stack = stack.tail
+ val set = Set[Class[_]]()
+ cr.accept(new InnerClosureFinder(set), 0)
+ for (cls <- set -- seen) {
+ seen += cls
+ stack = cls :: stack
+ }
+ }
+ return (seen - obj.getClass).toList
+ }
+
+ private def createNullValue(cls: Class[_]): AnyRef = {
+ if (cls.isPrimitive) {
+ new java.lang.Byte(0: Byte) // Should be convertible to any primitive type
+ } else {
+ null
+ }
+ }
+
+ def clean(func: AnyRef) {
+ // TODO: cache outerClasses / innerClasses / accessedFields
+ val outerClasses = getOuterClasses(func)
+ val innerClasses = getInnerClasses(func)
+ val outerObjects = getOuterObjects(func)
+
+ val accessedFields = Map[Class[_], Set[String]]()
+ for (cls <- outerClasses)
+ accessedFields(cls) = Set[String]()
+ for (cls <- func.getClass :: innerClasses)
+ getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
+ //logInfo("accessedFields: " + accessedFields)
+
+ val inInterpreter = {
+ try {
+ val interpClass = Class.forName("spark.repl.Main")
+ interpClass.getMethod("interp").invoke(null) != null
+ } catch {
+ case _: ClassNotFoundException => true
+ }
+ }
+
+ var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
+ var outer: AnyRef = null
+ if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
+ // The closure is ultimately nested inside a class; keep the object of that
+ // class without cloning it since we don't want to clone the user's objects.
+ outer = outerPairs.head._2
+ outerPairs = outerPairs.tail
+ }
+ // Clone the closure objects themselves, nulling out any fields that are not
+ // used in the closure we're working on or any of its inner closures.
+ for ((cls, obj) <- outerPairs) {
+ outer = instantiateClass(cls, outer, inInterpreter)
+ for (fieldName <- accessedFields(cls)) {
+ val field = cls.getDeclaredField(fieldName)
+ field.setAccessible(true)
+ val value = field.get(obj)
+ //logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
+ field.set(outer, value)
+ }
+ }
+
+ if (outer != null) {
+ //logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
+ val field = func.getClass.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(func, outer)
+ }
+ }
+
+ private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
+ //logInfo("Creating a " + cls + " with outer = " + outer)
+ if (!inInterpreter) {
+ // This is a bona fide closure class, whose constructor has no effects
+ // other than to set its fields, so use its constructor
+ val cons = cls.getConstructors()(0)
+ val params = cons.getParameterTypes.map(createNullValue).toArray
+ if (outer != null)
+ params(0) = outer // First param is always outer object
+ return cons.newInstance(params: _*).asInstanceOf[AnyRef]
+ } else {
+ // Use reflection to instantiate object without calling constructor
+ val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
+ val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
+ val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
+ val obj = newCtor.newInstance().asInstanceOf[AnyRef]
+ if (outer != null) {
+ //logInfo("3: Setting $outer on " + cls + " to " + outer);
+ val field = cls.getDeclaredField("$outer")
+ field.setAccessible(true)
+ field.set(obj, outer)
+ }
+ return obj
+ }
+ }
+}
+
+private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new MethodVisitor(ASM4) {
+ override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
+ if (op == GETFIELD) {
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
+ output(cl) += name
+ }
+ }
+ }
+
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ // Check for calls to a getter method for a variable in an interpreter wrapper object.
+ // This means that the corresponding field will be accessed, so we should save it.
+ if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
+ for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
+ output(cl) += name
+ }
+ }
+ }
+ }
+ }
+}
+
+private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
+ var myName: String = null
+
+ override def visit(version: Int, access: Int, name: String, sig: String,
+ superName: String, interfaces: Array[String]) {
+ myName = name
+ }
+
+ override def visitMethod(access: Int, name: String, desc: String,
+ sig: String, exceptions: Array[String]): MethodVisitor = {
+ return new MethodVisitor(ASM4) {
+ override def visitMethodInsn(op: Int, owner: String, name: String,
+ desc: String) {
+ val argTypes = Type.getArgumentTypes(desc)
+ if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
+ && argTypes(0).toString.startsWith("L") // is it an object?
+ && argTypes(0).getInternalName == myName)
+ output += Class.forName(
+ owner.replace('/', '.'),
+ false,
+ Thread.currentThread.getContextClassLoader)
+ }
+ }
+ }
+}
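As the comments in getOuterClasses note, the cleaner clones and prunes enclosing closure objects but deliberately stops at the first non-closure $outer, so it never nulls fields of the user's own classes. A hedged sketch of the usual consequence for user code follows (ClosureCleaner itself is private[spark] and is invoked by Spark before shipping a function to executors; the Analysis class below is hypothetical):

    import org.apache.spark.{RDD, SparkContext}

    class Analysis(sc: SparkContext) {
      val threshold = 42

      def bigValues(data: RDD[Int]): RDD[Int] = {
        // `_ > threshold` would capture `this` (and with it `sc`) via its $outer field,
        // and the cleaner keeps non-closure outer objects intact. Copying the value
        // into a local keeps the shipped closure small and serializable.
        val t = threshold
        data.filter(_ > t)
      }
    }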
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
new file mode 100644
index 0000000000..cc3c2474a6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Base class for dependencies.
+ */
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+
+
+/**
+ * Base class for dependencies where each partition of the parent RDD is used by at most one
+ * partition of the child RDD. Narrow dependencies allow for pipelined execution.
+ */
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+ /**
+ * Get the parent partitions for a child partition.
+ * @param partitionId a partition of the child RDD
+ * @return the partitions of the parent RDD that the child partition depends upon
+ */
+ def getParents(partitionId: Int): Seq[Int]
+}
+
+
+/**
+ * Represents a dependency on the output of a shuffle stage.
+ * @param rdd the parent RDD
+ * @param partitioner partitioner used to partition the shuffle output
+ * @param serializerClass class name of the serializer to use
+ */
+class ShuffleDependency[K, V](
+ @transient rdd: RDD[_ <: Product2[K, V]],
+ val partitioner: Partitioner,
+ val serializerClass: String = null)
+ extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
+
+ val shuffleId: Int = rdd.context.newShuffleId()
+}
+
+
+/**
+ * Represents a one-to-one dependency between partitions of the parent and child RDDs.
+ */
+class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
+ override def getParents(partitionId: Int) = List(partitionId)
+}
+
+
+/**
+ * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
+ * @param rdd the parent RDD
+ * @param inStart the start of the range in the parent RDD
+ * @param outStart the start of the range in the child RDD
+ * @param length the length of the range
+ */
+class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
+ extends NarrowDependency[T](rdd) {
+
+ override def getParents(partitionId: Int) = {
+ if (partitionId >= outStart && partitionId < outStart + length) {
+ List(partitionId - outStart + inStart)
+ } else {
+ Nil
+ }
+ }
+}
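A short worked example of RangeDependency.getParents: a parent RDD whose 3 partitions are spliced in as child partitions 5 through 7, as a union-style RDD would arrange (`parentRdd` is assumed to exist):

    val dep = new RangeDependency[Int](parentRdd, inStart = 0, outStart = 5, length = 3)

    dep.getParents(5)   // List(0): 5 - outStart + inStart
    dep.getParents(7)   // List(2)
    dep.getParents(4)   // Nil: outside [outStart, outStart + length)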
diff --git a/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala
new file mode 100644
index 0000000000..dd344491b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/DoubleRDDFunctions.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.partial.BoundedDouble
+import org.apache.spark.partial.MeanEvaluator
+import org.apache.spark.partial.PartialResult
+import org.apache.spark.partial.SumEvaluator
+import org.apache.spark.util.StatCounter
+
+/**
+ * Extra functions available on RDDs of Doubles through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
+ /** Add up the elements in this RDD. */
+ def sum(): Double = {
+ self.reduce(_ + _)
+ }
+
+ /**
+ * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count
+ * of the RDD's elements in one operation.
+ */
+ def stats(): StatCounter = {
+ self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
+ }
+
+ /** Compute the mean of this RDD's elements. */
+ def mean(): Double = stats().mean
+
+ /** Compute the variance of this RDD's elements. */
+ def variance(): Double = stats().variance
+
+ /** Compute the standard deviation of this RDD's elements. */
+ def stdev(): Double = stats().stdev
+
+ /**
+ * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+ * estimating the standard deviation by dividing by N-1 instead of N).
+ */
+ def sampleStdev(): Double = stats().sampleStdev
+
+ /**
+ * Compute the sample variance of this RDD's elements (which corrects for bias in
+ * estimating the variance by dividing by N-1 instead of N).
+ */
+ def sampleVariance(): Double = stats().sampleVariance
+
+ /** (Experimental) Approximate operation to return the mean within a timeout. */
+ def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new MeanEvaluator(self.partitions.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+
+ /** (Experimental) Approximate operation to return the sum within a timeout. */
+ def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
+ val evaluator = new SumEvaluator(self.partitions.size, confidence)
+ self.context.runApproximateJob(self, processPartition, evaluator, timeout)
+ }
+}
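A brief usage sketch of the implicit conversion path into this class, assuming a live SparkContext `sc`:

    import org.apache.spark.SparkContext._   // RDD[Double] picks up these methods implicitly

    val xs = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))

    xs.sum()       // 10.0
    xs.mean()      // 2.5
    xs.variance()  // 1.25
    xs.stats()     // one pass over the data: count, mean, variance via StatCounter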
diff --git a/core/src/main/scala/org/apache/spark/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/FetchFailedException.scala
new file mode 100644
index 0000000000..d242047502
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FetchFailedException.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.storage.BlockManagerId
+
+private[spark] class FetchFailedException(
+ taskEndReason: TaskEndReason,
+ message: String,
+ cause: Throwable)
+ extends Exception {
+
+ def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
+ "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
+ cause)
+
+ def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
+ this(FetchFailed(null, shuffleId, -1, reduceId),
+ "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
+
+ override def getMessage(): String = message
+
+
+ override def getCause(): Throwable = cause
+
+ def toTaskEndReason: TaskEndReason = taskEndReason
+
+}
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
new file mode 100644
index 0000000000..9b3a896648
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+import com.google.common.io.Files
+
+private[spark] class HttpFileServer extends Logging {
+
+ var baseDir: File = null
+ var fileDir: File = null
+ var jarDir: File = null
+ var httpServer: HttpServer = null
+ var serverUri: String = null
+
+ def initialize() {
+ baseDir = Utils.createTempDir()
+ fileDir = new File(baseDir, "files")
+ jarDir = new File(baseDir, "jars")
+ fileDir.mkdir()
+ jarDir.mkdir()
+ logInfo("HTTP File server directory is " + baseDir)
+ httpServer = new HttpServer(baseDir)
+ httpServer.start()
+ serverUri = httpServer.uri
+ }
+
+ def stop() {
+ httpServer.stop()
+ }
+
+ def addFile(file: File): String = {
+ addFileToDir(file, fileDir)
+ serverUri + "/files/" + file.getName
+ }
+
+ def addJar(file: File): String = {
+ addFileToDir(file, jarDir)
+ serverUri + "/jars/" + file.getName
+ }
+
+ def addFileToDir(file: File, dir: File): String = {
+ Files.copy(file, new File(dir, file.getName))
+ dir + "/" + file.getName
+ }
+
+}
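+
+// Usage sketch (illustrative; the jar path is hypothetical): driver-side use of HttpFileServer
+// to serve a local jar to executors over HTTP.
+//
+//   val fileServer = new HttpFileServer()
+//   fileServer.initialize()
+//   val jarUri = fileServer.addJar(new File("/tmp/app.jar"))  // e.g. http://<host>:<port>/jars/app.jar
+//   fileServer.stop()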
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
new file mode 100644
index 0000000000..db36c7c9dd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+import java.net.InetAddress
+
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.bio.SocketConnector
+import org.eclipse.jetty.server.handler.DefaultHandler
+import org.eclipse.jetty.server.handler.HandlerList
+import org.eclipse.jetty.server.handler.ResourceHandler
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+
+/**
+ * Exception type thrown by HttpServer when it is in the wrong state for an operation.
+ */
+private[spark] class ServerStateException(message: String) extends Exception(message)
+
+/**
+ * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
+ * as well as classes created by the interpreter when the user types in code. This is just a wrapper
+ * around a Jetty server.
+ */
+private[spark] class HttpServer(resourceBase: File) extends Logging {
+ private var server: Server = null
+ private var port: Int = -1
+
+ def start() {
+ if (server != null) {
+ throw new ServerStateException("Server is already started")
+ } else {
+ server = new Server()
+ val connector = new SocketConnector
+ connector.setMaxIdleTime(60*1000)
+ connector.setSoLingerTime(-1)
+ connector.setPort(0)
+ server.addConnector(connector)
+
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val resHandler = new ResourceHandler
+ resHandler.setResourceBase(resourceBase.getAbsolutePath)
+ val handlerList = new HandlerList
+ handlerList.setHandlers(Array(resHandler, new DefaultHandler))
+ server.setHandler(handlerList)
+ server.start()
+ port = server.getConnectors()(0).getLocalPort()
+ }
+ }
+
+ def stop() {
+ if (server == null) {
+ throw new ServerStateException("Server is already stopped")
+ } else {
+ server.stop()
+ port = -1
+ server = null
+ }
+ }
+
+ /**
+ * Get the URI of this HTTP server (http://host:port)
+ */
+ def uri: String = {
+ if (server == null) {
+ throw new ServerStateException("Server is not started")
+ } else {
+ return "http://" + Utils.localIpAddress + ":" + port
+ }
+ }
+}
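+
+// Usage sketch (illustrative; the directory path is hypothetical): serving a directory of
+// static files and reading back the bound URI.
+//
+//   val server = new HttpServer(new File("/tmp/spark-static"))
+//   server.start()
+//   println("Serving files at " + server.uri)  // e.g. http://<local-ip>:<ephemeral-port>
+//   server.stop()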
diff --git a/core/src/main/scala/org/apache/spark/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/JavaSerializer.scala
new file mode 100644
index 0000000000..f43396cb6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/JavaSerializer.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+import java.nio.ByteBuffer
+
+import org.apache.spark.serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream}
+import org.apache.spark.util.ByteBufferInputStream
+
+private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
+ val objOut = new ObjectOutputStream(out)
+ def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
+ def flush() { objOut.flush() }
+ def close() { objOut.close() }
+}
+
+private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
+ extends DeserializationStream {
+ val objIn = new ObjectInputStream(in) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+
+ def readObject[T](): T = objIn.readObject().asInstanceOf[T]
+ def close() { objIn.close() }
+}
+
+private[spark] class JavaSerializerInstance extends SerializerInstance {
+ def serialize[T](t: T): ByteBuffer = {
+ val bos = new ByteArrayOutputStream()
+ val out = serializeStream(bos)
+ out.writeObject(t)
+ out.close()
+ ByteBuffer.wrap(bos.toByteArray)
+ }
+
+ def deserialize[T](bytes: ByteBuffer): T = {
+ val bis = new ByteBufferInputStream(bytes)
+ val in = deserializeStream(bis)
+ in.readObject().asInstanceOf[T]
+ }
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val bis = new ByteBufferInputStream(bytes)
+ val in = deserializeStream(bis, loader)
+ in.readObject().asInstanceOf[T]
+ }
+
+ def serializeStream(s: OutputStream): SerializationStream = {
+ new JavaSerializationStream(s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
+ }
+
+ def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
+ new JavaDeserializationStream(s, loader)
+ }
+}
+
+/**
+ * A Spark serializer that uses Java's built-in serialization.
+ */
+class JavaSerializer extends Serializer {
+ def newInstance(): SerializerInstance = new JavaSerializerInstance
+}
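+
+// Usage sketch (illustrative): a serialize/deserialize round trip through the
+// SerializerInstance API defined above.
+//
+//   val ser = new JavaSerializer().newInstance()
+//   val bytes: ByteBuffer = ser.serialize(Map("a" -> 1, "b" -> 2))
+//   val restored = ser.deserialize[Map[String, Int]](bytes)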
diff --git a/core/src/main/scala/org/apache/spark/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/KryoSerializer.scala
new file mode 100644
index 0000000000..db86e6db43
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/KryoSerializer.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+import java.nio.ByteBuffer
+import com.esotericsoftware.kryo.{Kryo, KryoException}
+import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
+import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer}
+import com.twitter.chill.ScalaKryoInstantiator
+import org.apache.spark.serializer.{SerializerInstance, DeserializationStream, SerializationStream}
+import org.apache.spark.broadcast._
+import org.apache.spark.storage._
+
+private[spark]
+class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
+ val output = new KryoOutput(outStream)
+
+ def writeObject[T](t: T): SerializationStream = {
+ kryo.writeClassAndObject(output, t)
+ this
+ }
+
+ def flush() { output.flush() }
+ def close() { output.close() }
+}
+
+private[spark]
+class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
+ val input = new KryoInput(inStream)
+
+ def readObject[T](): T = {
+ try {
+ kryo.readClassAndObject(input).asInstanceOf[T]
+ } catch {
+ // DeserializationStream uses the EOF exception to indicate stopping condition.
+ case _: KryoException => throw new EOFException
+ }
+ }
+
+ def close() {
+ // Kryo's Input automatically closes the input stream it is using.
+ input.close()
+ }
+}
+
+private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
+ val kryo = ks.newKryo()
+ val output = ks.newKryoOutput()
+ val input = ks.newKryoInput()
+
+ def serialize[T](t: T): ByteBuffer = {
+ output.clear()
+ kryo.writeClassAndObject(output, t)
+ ByteBuffer.wrap(output.toBytes)
+ }
+
+ def deserialize[T](bytes: ByteBuffer): T = {
+ input.setBuffer(bytes.array)
+ kryo.readClassAndObject(input).asInstanceOf[T]
+ }
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ val oldClassLoader = kryo.getClassLoader
+ kryo.setClassLoader(loader)
+ input.setBuffer(bytes.array)
+ val obj = kryo.readClassAndObject(input).asInstanceOf[T]
+ kryo.setClassLoader(oldClassLoader)
+ obj
+ }
+
+ def serializeStream(s: OutputStream): SerializationStream = {
+ new KryoSerializationStream(kryo, s)
+ }
+
+ def deserializeStream(s: InputStream): DeserializationStream = {
+ new KryoDeserializationStream(kryo, s)
+ }
+}
+
+/**
+ * Interface implemented by clients to register their classes with Kryo when using Kryo
+ * serialization.
+ */
+trait KryoRegistrator {
+ def registerClasses(kryo: Kryo)
+}
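+
+// Usage sketch (illustrative; MyCustomClass and the class name string are hypothetical):
+// a user-defined registrator, selected at runtime through the spark.kryo.registrator property
+// read in newKryo() below.
+//
+//   class MyRegistrator extends KryoRegistrator {
+//     override def registerClasses(kryo: Kryo) {
+//       kryo.register(classOf[MyCustomClass])
+//     }
+//   }
+//   System.setProperty("spark.kryo.registrator", "com.example.MyRegistrator")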
+
+/**
+ * A Spark serializer that uses the [[http://code.google.com/p/kryo/ Kryo serialization library]].
+ */
+class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
+ private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
+ def newKryoOutput() = new KryoOutput(bufferSize)
+
+ def newKryoInput() = new KryoInput(bufferSize)
+
+ def newKryo(): Kryo = {
+ val instantiator = new ScalaKryoInstantiator
+ val kryo = instantiator.newKryo()
+ val classLoader = Thread.currentThread.getContextClassLoader
+
+ // Register some commonly used classes
+ val toRegister: Seq[AnyRef] = Seq(
+ ByteBuffer.allocate(1),
+ StorageLevel.MEMORY_ONLY,
+ PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+ GotBlock("1", ByteBuffer.allocate(1)),
+ GetBlock("1")
+ )
+
+ for (obj <- toRegister) kryo.register(obj.getClass)
+
+ // Allow sending SerializableWritable
+ kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
+ kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
+
+ // Allow the user to register their own classes by setting spark.kryo.registrator
+ try {
+ Option(System.getProperty("spark.kryo.registrator")).foreach { regCls =>
+ logDebug("Running user registrator: " + regCls)
+ val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
+ reg.registerClasses(kryo)
+ }
+ } catch {
+ case e: Exception => logError("Failed to register spark.kryo.registrator", e)
+ }
+
+ kryo.setClassLoader(classLoader)
+
+ // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+
+ kryo
+ }
+
+ def newInstance(): SerializerInstance = {
+ new KryoSerializerInstance(this)
+ }
+}
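+
+// Configuration sketch (illustrative; the values are hypothetical): the system properties read
+// by KryoSerializer above, set before the serializer is instantiated.
+//
+//   System.setProperty("spark.kryoserializer.buffer.mb", "8")      // 8 MB serialization buffer
+//   System.setProperty("spark.kryo.referenceTracking", "false")    // only if object graphs have no cycles
+//   val instance = new KryoSerializer().newInstance()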
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
new file mode 100644
index 0000000000..6a973ea495
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+
+/**
+ * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
+ * logging messages at different levels using methods that only evaluate parameters lazily if the
+ * log level is enabled.
+ */
+trait Logging {
+ // Make the log field transient so that objects with Logging can
+ // be serialized and used on another machine
+ @transient private var log_ : Logger = null
+
+ // Method to get or create the logger for this object
+ protected def log: Logger = {
+ if (log_ == null) {
+ var className = this.getClass.getName
+ // Ignore trailing $'s in the class names for Scala objects
+ if (className.endsWith("$")) {
+ className = className.substring(0, className.length - 1)
+ }
+ log_ = LoggerFactory.getLogger(className)
+ }
+ return log_
+ }
+
+ // Log methods that take only a String
+ protected def logInfo(msg: => String) {
+ if (log.isInfoEnabled) log.info(msg)
+ }
+
+ protected def logDebug(msg: => String) {
+ if (log.isDebugEnabled) log.debug(msg)
+ }
+
+ protected def logTrace(msg: => String) {
+ if (log.isTraceEnabled) log.trace(msg)
+ }
+
+ protected def logWarning(msg: => String) {
+ if (log.isWarnEnabled) log.warn(msg)
+ }
+
+ protected def logError(msg: => String) {
+ if (log.isErrorEnabled) log.error(msg)
+ }
+
+ // Log methods that take Throwables (Exceptions/Errors) too
+ protected def logInfo(msg: => String, throwable: Throwable) {
+ if (log.isInfoEnabled) log.info(msg, throwable)
+ }
+
+ protected def logDebug(msg: => String, throwable: Throwable) {
+ if (log.isDebugEnabled) log.debug(msg, throwable)
+ }
+
+ protected def logTrace(msg: => String, throwable: Throwable) {
+ if (log.isTraceEnabled) log.trace(msg, throwable)
+ }
+
+ protected def logWarning(msg: => String, throwable: Throwable) {
+ if (log.isWarnEnabled) log.warn(msg, throwable)
+ }
+
+ protected def logError(msg: => String, throwable: Throwable) {
+ if (log.isErrorEnabled) log.error(msg, throwable)
+ }
+
+ protected def isTraceEnabled(): Boolean = {
+ log.isTraceEnabled
+ }
+
+ // Method for ensuring that logging is initialized, to avoid having multiple
+ // threads do it concurrently (as SLF4J initialization is not thread safe).
+ protected def initLogging() { log }
+}
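+
+// Usage sketch (illustrative; WorkerPool is hypothetical): because the msg parameters above are
+// by-name (msg: => String), the strings below are only built when the level is enabled.
+//
+//   class WorkerPool extends Logging {
+//     def run(tasks: Seq[String]) {
+//       logInfo("Starting " + tasks.size + " tasks")
+//       logDebug("Task list: " + tasks.mkString(", "))
+//     }
+//   }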
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
new file mode 100644
index 0000000000..0f422d910a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import akka.actor._
+import akka.dispatch._
+import akka.pattern.ask
+import akka.remote._
+import akka.util.Duration
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+
+private[spark] sealed trait MapOutputTrackerMessage
+private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
+ extends MapOutputTrackerMessage
+private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
+
+private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+ def receive = {
+ case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
+ logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
+ sender ! tracker.getSerializedLocations(shuffleId)
+
+ case StopMapOutputTracker =>
+ logInfo("MapOutputTrackerActor stopped!")
+ sender ! true
+ context.stop(self)
+ }
+}
+
+private[spark] class MapOutputTracker extends Logging {
+
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
+ // Set to the MapOutputTrackerActor living on the driver
+ var trackerActor: ActorRef = _
+
+ private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+
+ // Incremented every time a fetch fails so that client nodes know to clear
+ // their cache of map output locations if this happens.
+ private var epoch: Long = 0
+ private val epochLock = new java.lang.Object
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
+
+ // Send a message to the trackerActor and get its result within a default timeout, or
+ // throw a SparkException if this fails.
+ def askTracker(message: Any): Any = {
+ try {
+ val future = trackerActor.ask(message)(timeout)
+ return Await.result(future, timeout)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error communicating with MapOutputTracker", e)
+ }
+ }
+
+ // Send a one-way message to the trackerActor, to which we expect it to reply with true.
+ def communicate(message: Any) {
+ if (askTracker(message) != true) {
+ throw new SparkException("Error reply received from MapOutputTracker")
+ }
+ }
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(
+ shuffleId: Int,
+ statuses: Array[MapStatus],
+ changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
+ }
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
+ }
+ }
+
+ // Remembers which map output locations are currently being fetched on a worker
+ private val fetching = new HashSet[Int]
+
+ // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
+ val statuses = mapStatuses.get(shuffleId).orNull
+ if (statuses == null) {
+ logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+ var fetchedStatuses: Array[MapStatus] = null
+ fetching.synchronized {
+ if (fetching.contains(shuffleId)) {
+ // Someone else is fetching it; wait for them to be done
+ while (fetching.contains(shuffleId)) {
+ try {
+ fetching.wait()
+ } catch {
+ case e: InterruptedException =>
+ }
+ }
+ }
+
+ // Either while we waited the fetch happened successfully, or
+ // someone fetched it in between the get and the fetching.synchronized.
+ fetchedStatuses = mapStatuses.get(shuffleId).orNull
+ if (fetchedStatuses == null) {
+ // We have to do the fetch, get others to wait for us.
+ fetching += shuffleId
+ }
+ }
+
+ if (fetchedStatuses == null) {
+ // We won the race to fetch the output locs; do so
+ logInfo("Doing the fetch; tracker actor = " + trackerActor)
+ val hostPort = Utils.localHostPort()
+ // This try-finally prevents hangs due to timeouts:
+ try {
+ val fetchedBytes =
+ askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
+ fetchedStatuses = deserializeStatuses(fetchedBytes)
+ logInfo("Got the output locations")
+ mapStatuses.put(shuffleId, fetchedStatuses)
+ } finally {
+ fetching.synchronized {
+ fetching -= shuffleId
+ fetching.notifyAll()
+ }
+ }
+ }
+ if (fetchedStatuses != null) {
+ fetchedStatuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
+ }
+ } else {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing all output locations for shuffle " + shuffleId))
+ }
+ } else {
+ statuses.synchronized {
+ return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
+ }
+ }
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
+ def stop() {
+ communicate(StopMapOutputTracker)
+ mapStatuses.clear()
+ metadataCleaner.cancel()
+ trackerActor = null
+ }
+
+ // Called on master to increment the epoch number
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ // Called on master or workers to get current epoch number
+ def getEpoch: Long = {
+ epochLock.synchronized {
+ return epoch
+ }
+ }
+
+ // Called on workers to update the epoch number, potentially clearing old outputs
+ // because of a fetch failure. (Each worker task calls this with the latest epoch
+ // number on the master at the time it was created.)
+ def updateEpoch(newEpoch: Long) {
+ epochLock.synchronized {
+ if (newEpoch > epoch) {
+ logInfo("Updating epoch to " + newEpoch + " and clearing cache")
+ // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ mapStatuses.clear()
+ epoch = newEpoch
+ }
+ }
+ }
+
+ def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ var statuses: Array[MapStatus] = null
+ var epochGotten: Long = -1
+ epochLock.synchronized {
+ if (epoch > cacheEpoch) {
+ cachedSerializedStatuses.clear()
+ cacheEpoch = epoch
+ }
+ cachedSerializedStatuses.get(shuffleId) match {
+ case Some(bytes) =>
+ return bytes
+ case None =>
+ statuses = mapStatuses(shuffleId)
+ epochGotten = epoch
+ }
+ }
+ // If we got here, we failed to find the serialized locations in the cache, so we pulled
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
+ val bytes = serializeStatuses(statuses)
+ logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
+ // Add them into the table only if the epoch hasn't changed while we were working
+ epochLock.synchronized {
+ if (epoch == epochGotten) {
+ cachedSerializedStatuses(shuffleId) = bytes
+ }
+ }
+ return bytes
+ }
+
+ // Serialize an array of map output locations into an efficient byte format so that we can send
+ // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
+ // generally be pretty compressible because many map outputs will be on the same hostname.
+ private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ val out = new ByteArrayOutputStream
+ val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
+ // Since statuses can be modified in parallel, sync on it
+ statuses.synchronized {
+ objOut.writeObject(statuses)
+ }
+ objOut.close()
+ out.toByteArray
+ }
+
+ // Opposite of serializeStatuses.
+ def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
+ // Note: the array may contain null entries for missing map output locations (e.g. after a
+ // failed mapper). They are not filtered out here; convertMapStatuses turns them into
+ // FetchFailedExceptions on the reading side.
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
+ // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
+ // any of the statuses is null (indicating a missing location due to a failed mapper),
+ // throw a FetchFailedException.
+ private def convertMapStatuses(
+ shuffleId: Int,
+ reduceId: Int,
+ statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
+ assert (statuses != null)
+ statuses.map {
+ status =>
+ if (status == null) {
+ throw new FetchFailedException(null, shuffleId, -1, reduceId,
+ new Exception("Missing an output location for shuffle " + shuffleId))
+ } else {
+ (status.location, decompressSize(status.compressedSizes(reduceId)))
+ }
+ }
+ }
+
+ /**
+ * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
+ * We do this by encoding the log base 1.1 of the size as an integer, which can support
+ * sizes up to 35 GB with at most 10% error.
+ */
+ def compressSize(size: Long): Byte = {
+ if (size == 0) {
+ 0
+ } else if (size <= 1L) {
+ 1
+ } else {
+ math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
+ }
+ }
+
+ /**
+ * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
+ */
+ def decompressSize(compressedSize: Byte): Long = {
+ if (compressedSize == 0) {
+ 0
+ } else {
+ math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
+ }
+ }
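+
+ // Worked example (illustrative): for a 1 MB map output, compressSize(1000000L) stores
+ // ceil(log(1e6) / log(1.1)) = 145 in a single byte, and decompressSize(145.toByte) returns
+ // math.pow(1.1, 145).toLong ~= 1004000, i.e. the original size within the documented 10% error.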
+}
diff --git a/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala
new file mode 100644
index 0000000000..d046e7c1a4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/PairRDDFunctions.scala
@@ -0,0 +1,703 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.nio.ByteBuffer
+import java.util.{Date, HashMap => JHashMap}
+import java.text.SimpleDateFormat
+
+import scala.collection.{mutable, Map}
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.SequenceFile.CompressionType
+import org.apache.hadoop.mapred.FileOutputCommitter
+import org.apache.hadoop.mapred.FileOutputFormat
+import org.apache.hadoop.mapred.SparkHadoopWriter
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.OutputFormat
+
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat,
+ RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, SparkHadoopMapReduceUtil}
+import org.apache.hadoop.security.UserGroupInformation
+
+import org.apache.spark.partial.BoundedDouble
+import org.apache.spark.partial.PartialResult
+import org.apache.spark.rdd._
+import org.apache.spark.SparkContext._
+import org.apache.spark.Partitioner._
+
+/**
+ * Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
+ extends Logging
+ with SparkHadoopMapReduceUtil
+ with Serializable {
+
+ /**
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C.
+ * Note that V and C can be different -- for example, one might group an RDD of type
+ * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
+ *
+ * In addition, users can control the partitioning of the output RDD, and whether to perform
+ * map-side aggregation (if a mapper can produce multiple items with the same key).
+ */
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ partitioner: Partitioner,
+ mapSideCombine: Boolean = true,
+ serializerClass: String = null): RDD[(K, C)] = {
+ if (getKeyClass().isArray) {
+ if (mapSideCombine) {
+ throw new SparkException("Cannot use map-side combining with array keys.")
+ }
+ if (partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ }
+ val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ if (self.partitioner == Some(partitioner)) {
+ self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ } else if (mapSideCombine) {
+ val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
+ .setSerializer(serializerClass)
+ partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+ } else {
+ // Don't apply map-side combiner.
+ // A sanity check to make sure mergeCombiners is not defined.
+ assert(mergeCombiners == null)
+ val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
+ values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ }
+ }
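+
+ // Usage sketch (illustrative; `pairs` is a hypothetical RDD[(String, Int)]): building a
+ // per-key list of values with the three functions described above.
+ //
+ //   val grouped = pairs.combineByKey[List[Int]](
+ //     (v: Int) => List(v),                          // createCombiner
+ //     (c: List[Int], v: Int) => v :: c,             // mergeValue
+ //     (c1: List[Int], c2: List[Int]) => c1 ::: c2,  // mergeCombiners
+ //     new HashPartitioner(4))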
+
+ /**
+ * Simplified version of combineByKey that hash-partitions the output RDD.
+ */
+ def combineByKey[C](createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ numPartitions: Int): RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication).
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ // When deserializing, use a lazy val to create just one instance of the serializer per task
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication).
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
+ }
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which
+ * may be added to the result an arbitrary number of times, and must not change the result
+ * (e.g., Nil for list concatenation, 0 for addition, or 1 for multiplication).
+ */
+ def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
+ foldByKey(zeroValue, defaultPartitioner(self))(func)
+ }
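+
+ // Usage sketch (illustrative; `counts` is a hypothetical RDD[(String, Int)]): summing values
+ // per key, where 0 is the neutral zero value for addition.
+ //
+ //   val totalsA = counts.foldByKey(0)(_ + _)
+ //   val totalsB = counts.reduceByKey(_ + _)   // same result, defined further below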
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce.
+ */
+ def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
+ combineByKey[V]((v: V) => v, func, func, partitioner)
+ }
+
+ /**
+ * Merge the values for each key using an associative reduce function, but return the results
+ * immediately to the master as a Map. This will also perform the merging locally on each mapper
+ * before sending results to a reducer, similarly to a "combiner" in MapReduce.
+ */
+ def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
+
+ if (getKeyClass().isArray) {
+ throw new SparkException("reduceByKeyLocally() does not support array keys")
+ }
+
+ def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
+ val map = new JHashMap[K, V]
+ iter.foreach { case (k, v) =>
+ val old = map.get(k)
+ map.put(k, if (old == null) v else func(old, v))
+ }
+ Iterator(map)
+ }
+
+ def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = {
+ m2.foreach { case (k, v) =>
+ val old = m1.get(k)
+ m1.put(k, if (old == null) v else func(old, v))
+ }
+ m1
+ }
+
+ self.mapPartitions(reducePartition).reduce(mergeMaps)
+ }
+
+ /** Alias for reduceByKeyLocally */
+ def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
+
+ /** Count the number of elements for each key, and return the result to the master as a Map. */
+ def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
+ def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+ : PartialResult[Map[K, BoundedDouble]] = {
+ self.map(_._1).countByValueApprox(timeout, confidence)
+ }
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
+ */
+ def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = {
+ reduceByKey(new HashPartitioner(numPartitions), func)
+ }
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Allows controlling the
+ * partitioning of the resulting key-value pair RDD by passing a Partitioner.
+ */
+ def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
+ // groupByKey shouldn't use map side combine because map side combine does not
+ // reduce the amount of data shuffled and requires all map side data be inserted
+ // into a hash table, leading to more objects in the old gen.
+ def createCombiner(v: V) = ArrayBuffer(v)
+ def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
+ val bufs = combineByKey[ArrayBuffer[V]](
+ createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
+ bufs.asInstanceOf[RDD[(K, Seq[V])]]
+ }
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD into `numPartitions` partitions.
+ */
+ def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = {
+ groupByKey(new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner.
+ */
+ def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
+ if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ new ShuffledRDD[K, V, (K, V)](self, partitioner)
+ }
+
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
+ */
+ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
+ }
+ }
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
+ def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ if (ws.isEmpty) {
+ vs.iterator.map(v => (v, None))
+ } else {
+ for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w))
+ }
+ }
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
+ def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
+ : RDD[(K, (Option[V], W))] = {
+ this.cogroup(other, partitioner).flatMapValues { case (vs, ws) =>
+ if (vs.isEmpty) {
+ ws.iterator.map(w => (None, w))
+ } else {
+ for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w)
+ }
+ }
+ }
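+
+ // Usage sketch (illustrative): expected contents of the joins above for two small pair RDDs
+ // a = {(1, "x"), (2, "y")} and b = {(1, 10)}.
+ //
+ //   a.join(b)            // (1, ("x", 10))
+ //   a.leftOuterJoin(b)   // (1, ("x", Some(10))), (2, ("y", None))
+ //   a.rightOuterJoin(b)  // (1, (Some("x"), 10))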
+
+ /**
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the
+ * existing partitioner/parallelism level.
+ */
+ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
+ : RDD[(K, C)] = {
+ combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
+ }
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
+ * parallelism level.
+ */
+ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
+ reduceByKey(defaultPartitioner(self), func)
+ }
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD with the existing partitioner/parallelism level.
+ */
+ def groupByKey(): RDD[(K, Seq[V])] = {
+ groupByKey(defaultPartitioner(self))
+ }
+
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
+ def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
+ join(other, defaultPartitioner(self, other))
+ }
+
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
+ def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = {
+ join(other, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * using the existing partitioner/parallelism level.
+ */
+ def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
+ leftOuterJoin(other, defaultPartitioner(self, other))
+ }
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * into `numPartitions` partitions.
+ */
+ def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = {
+ leftOuterJoin(other, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD using the existing partitioner/parallelism level.
+ */
+ def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
+ rightOuterJoin(other, defaultPartitioner(self, other))
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD into the given number of partitions.
+ */
+ def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = {
+ rightOuterJoin(other, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * Return the key-value pairs in this RDD to the master as a Map.
+ */
+ def collectAsMap(): Map[K, V] = {
+ val data = self.toArray()
+ val map = new mutable.HashMap[K, V]
+ map.sizeHint(data.length)
+ data.foreach { case (k, v) => map.put(k, v) }
+ map
+ }
+
+ /**
+ * Pass each value in the key-value pair RDD through a map function without changing the keys;
+ * this also retains the original RDD's partitioning.
+ */
+ def mapValues[U](f: V => U): RDD[(K, U)] = {
+ val cleanF = self.context.clean(f)
+ new MappedValuesRDD(self, cleanF)
+ }
+
+ /**
+ * Pass each value in the key-value pair RDD through a flatMap function without changing the
+ * keys; this also retains the original RDD's partitioning.
+ */
+ def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = {
+ val cleanF = self.context.clean(f)
+ new FlatMappedValuesRDD(self, cleanF)
+ }
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ prfs.mapValues { case Seq(vs, ws) =>
+ (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
+ }
+ }
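+
+ // Usage sketch (illustrative): for the same small pair RDDs a = {(1, "x"), (2, "y")} and
+ // b = {(1, 10)}, cogroup keeps every key from either side:
+ //
+ //   a.cogroup(b)   // (1, (Seq("x"), Seq(10))), (2, (Seq("y"), Seq()))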
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
+ : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) {
+ throw new SparkException("Default partitioner cannot partition array keys.")
+ }
+ val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ prfs.mapValues { case Seq(vs, w1s, w2s) =>
+ (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])
+ }
+ }
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner(self, other))
+ }
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
+ : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ cogroup(other1, other2, defaultPartitioner(self, other1, other2))
+ }
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, new HashPartitioner(numPartitions))
+ }
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int)
+ : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ cogroup(other1, other2, new HashPartitioner(numPartitions))
+ }
+
+ /** Alias for cogroup. */
+ def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner(self, other))
+ }
+
+ /** Alias for cogroup. */
+ def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
+ : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
+ cogroup(other1, other2, defaultPartitioner(self, other1, other2))
+ }
+
+ /**
+ * Return an RDD with the pairs from `this` whose keys are not in `other`.
+ *
+ * Uses `this` RDD's partitioner/partition size, because even if `other` is huge, the
+ * resulting RDD will be no larger than `this`.
+ */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ subtractByKey(other, new HashPartitioner(numPartitions))
+
+ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+ def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ new SubtractedRDD[K, V, W](self, other, p)
+
+ /**
+ * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
+ * RDD has a known partitioner by only searching the partition that the key maps to.
+ */
+ def lookup(key: K): Seq[V] = {
+ self.partitioner match {
+ case Some(p) =>
+ val index = p.getPartition(key)
+ def process(it: Iterator[(K, V)]): Seq[V] = {
+ val buf = new ArrayBuffer[V]
+ for ((k, v) <- it if k == key) {
+ buf += v
+ }
+ buf
+ }
+ val res = self.context.runJob(self, process _, Array(index), false)
+ res(0)
+ case None =>
+ self.filter(_._1 == key).map(_._2).collect()
+ }
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD.
+ */
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress the result with the
+ * supplied codec.
+ */
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
+ * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
+ */
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
+ saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
+ * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
+ */
+ def saveAsNewAPIHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
+ conf: Configuration = self.context.hadoopConfiguration) {
+ val job = new NewAPIHadoopJob(conf)
+ job.setOutputKeyClass(keyClass)
+ job.setOutputValueClass(valueClass)
+ val wrappedConf = new SerializableWritable(job.getConfiguration)
+ NewFileOutputFormat.setOutputPath(job, new Path(path))
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ val stageId = self.id
+ def writeShard(context: TaskContext, iter: Iterator[(K,V)]): Int = {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ /* "reduce task" <split #> <attempt # = spark task #> */
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val format = outputFormatClass.newInstance
+ val committer = format.getOutputCommitter(hadoopContext)
+ committer.setupTask(hadoopContext)
+ val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ while (iter.hasNext) {
+ val (k, v) = iter.next()
+ writer.write(k, v)
+ }
+ writer.close(hadoopContext)
+ committer.commitTask(hadoopContext)
+ return 1
+ }
+ val jobFormat = outputFormatClass.newInstance
+ /* apparently we need a TaskAttemptID to construct an OutputCommitter;
+ * however we're only going to use this local OutputCommitter for
+ * setupJob/commitJob, so we just use a dummy "map" task.
+ */
+ val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
+ val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
+ val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
+ jobCommitter.setupJob(jobTaskContext)
+ val count = self.context.runJob(self, writeShard _).sum
+ jobCommitter.commitJob(jobTaskContext)
+ jobCommitter.cleanupJob(jobTaskContext)
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD. Compress with the supplied codec.
+ */
+ def saveAsHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ codec: Class[_ <: CompressionCodec]) {
+ saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
+ new JobConf(self.context.hadoopConfiguration), Some(codec))
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
+ * supporting the key and value types K and V in this RDD.
+ */
+ def saveAsHadoopFile(
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ conf: JobConf = new JobConf(self.context.hadoopConfiguration),
+ codec: Option[Class[_ <: CompressionCodec]] = None) {
+ conf.setOutputKeyClass(keyClass)
+ conf.setOutputValueClass(valueClass)
+ // conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
+ conf.set("mapred.output.format.class", outputFormatClass.getName)
+ for (c <- codec) {
+ conf.setCompressMapOutput(true)
+ conf.set("mapred.output.compress", "true")
+ conf.setMapOutputCompressorClass(c)
+ conf.set("mapred.output.compression.codec", c.getCanonicalName)
+ conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
+ }
+ conf.setOutputCommitter(classOf[FileOutputCommitter])
+ FileOutputFormat.setOutputPath(conf, SparkHadoopWriter.createPathFromString(path, conf))
+ saveAsHadoopDataset(conf)
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
+ * that storage system. The JobConf should set an OutputFormat and any output paths required
+ * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
+ * MapReduce job.
+ */
+ def saveAsHadoopDataset(conf: JobConf) {
+ val outputFormatClass = conf.getOutputFormat
+ val keyClass = conf.getOutputKeyClass
+ val valueClass = conf.getOutputValueClass
+ if (outputFormatClass == null) {
+ throw new SparkException("Output format class not set")
+ }
+ if (keyClass == null) {
+ throw new SparkException("Output key class not set")
+ }
+ if (valueClass == null) {
+ throw new SparkException("Output value class not set")
+ }
+
+ logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")")
+
+ val writer = new SparkHadoopWriter(conf)
+ writer.preSetup()
+
+ def writeToFile(context: TaskContext, iter: Iterator[(K, V)]) {
+ // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+ // around by taking a mod. We expect that no task will be attempted 2 billion times.
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+
+ writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.open()
+
+ var count = 0
+ while (iter.hasNext) {
+ val record = iter.next()
+ count += 1
+ writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+ }
+
+ writer.close()
+ writer.commit()
+ }
+
+ self.context.runJob(self, writeToFile _)
+ writer.commitJob()
+ writer.cleanup()
+ }
+
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys: RDD[K] = self.map(_._1)
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values: RDD[V] = self.map(_._2)
+
+ private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
+
+ private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
+}
+
+
+private[spark] object Manifests {
+ val seqSeqManifest = classManifest[Seq[Seq[_]]]
+}
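For reference, a minimal usage sketch of the pair-RDD save path defined above. It assumes a SparkContext named `sc` is in scope and that `import org.apache.spark.SparkContext._` provides the implicit conversion to PairRDDFunctions; the output path and data are illustrative.

    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.io.compress.GzipCodec
    import org.apache.hadoop.mapred.TextOutputFormat

    // Build a small pair RDD and aggregate it.
    val counts = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3)).reduceByKey(_ + _)

    // keys / values projections defined above.
    val ks = counts.keys      // RDD[String]
    val vs = counts.values    // RDD[Int]

    // Save with an explicit OutputFormat and compression codec, matching the
    // saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec) overload.
    counts
      .map { case (k, v) => (new Text(k), new IntWritable(v)) }
      .saveAsHadoopFile("/tmp/word-counts", classOf[Text], classOf[IntWritable],
        classOf[TextOutputFormat[Text, IntWritable]], classOf[GzipCodec])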
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
new file mode 100644
index 0000000000..87914a061f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * A partition of an RDD.
+ */
+trait Partition extends Serializable {
+ /**
+ * Get the split's index within its parent RDD
+ */
+ def index: Int
+
+ // A better default implementation of hashCode
+ override def hashCode(): Int = index
+}
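A hypothetical sketch of an implementation: an RDD that computes over a numeric range could carry the range bounds in its Partition objects. Only `index` must be provided; the hashCode default above then applies. The class and field names are illustrative.

    import org.apache.spark.Partition

    // One split of a hypothetical range-based RDD.
    class RangeBlock(val index: Int, val start: Long, val end: Long) extends Partition

    // A custom RDD's getPartitions would typically build such splits, e.g.:
    //   protected def getPartitions: Array[Partition] =
    //     (0 until 4).map(i => new RangeBlock(i, i * 25L, (i + 1) * 25L)).toArray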
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
new file mode 100644
index 0000000000..4dce2607b0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An object that defines how the elements in a key-value pair RDD are partitioned by key.
+ * Maps each key to a partition ID, from 0 to `numPartitions - 1`.
+ */
+abstract class Partitioner extends Serializable {
+ def numPartitions: Int
+ def getPartition(key: Any): Int
+}
+
+object Partitioner {
+ /**
+ * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
+ *
+ * If any of the RDDs already has a partitioner, choose that one.
+ *
+ * Otherwise, we use a default HashPartitioner. For the number of partitions, if
+ * spark.default.parallelism is set, then we'll use the value from SparkContext
+ * defaultParallelism, otherwise we'll use the max number of upstream partitions.
+ *
+ * Unless spark.default.parallelism is set, the number of partitions will be the
+ * same as the number of partitions in the largest upstream RDD, as this should
+ * be least likely to cause out-of-memory errors.
+ *
+ * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
+ */
+ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
+ val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
+ for (r <- bySize if r.partitioner != None) {
+ return r.partitioner.get
+ }
+ if (System.getProperty("spark.default.parallelism") != null) {
+ return new HashPartitioner(rdd.context.defaultParallelism)
+ } else {
+ return new HashPartitioner(bySize.head.partitions.size)
+ }
+ }
+}
+
+/**
+ * A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
+ *
+ * Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
+ * so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
+ * produce an unexpected or incorrect result.
+ */
+class HashPartitioner(partitions: Int) extends Partitioner {
+ def numPartitions = partitions
+
+ def getPartition(key: Any): Int = key match {
+ case null => 0
+ case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: HashPartitioner =>
+ h.numPartitions == numPartitions
+ case _ =>
+ false
+ }
+}
+
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
+ * Determines the ranges by sampling the RDD passed in.
+ */
+class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+ partitions: Int,
+ @transient rdd: RDD[_ <: Product2[K,V]],
+ private val ascending: Boolean = true)
+ extends Partitioner {
+
+ // An array of upper bounds for the first (partitions - 1) partitions
+ private val rangeBounds: Array[K] = {
+ if (partitions == 1) {
+ Array()
+ } else {
+ val rddSize = rdd.count()
+ val maxSampleSize = partitions * 20.0
+ val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
+ val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
+ if (rddSample.length == 0) {
+ Array()
+ } else {
+ val bounds = new Array[K](partitions - 1)
+ for (i <- 0 until partitions - 1) {
+ val index = (rddSample.length - 1) * (i + 1) / partitions
+ bounds(i) = rddSample(index)
+ }
+ bounds
+ }
+ }
+ }
+
+ def numPartitions = partitions
+
+ def getPartition(key: Any): Int = {
+ // TODO: Use a binary search here if number of partitions is large
+ val k = key.asInstanceOf[K]
+ var partition = 0
+ while (partition < rangeBounds.length && k > rangeBounds(partition)) {
+ partition += 1
+ }
+ if (ascending) {
+ partition
+ } else {
+ rangeBounds.length - partition
+ }
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case r: RangePartitioner[_,_] =>
+ r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
+ case _ =>
+ false
+ }
+}
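A sketch of a custom Partitioner next to the built-in HashPartitioner. It assumes `sc`, the SparkContext._ implicits, and that PairRDDFunctions provides partitionBy as in the standard API; the class and data are illustrative.

    import org.apache.spark.{HashPartitioner, Partitioner}

    // Hypothetical partitioner that routes String keys by their first character.
    class FirstCharPartitioner(override val numPartitions: Int) extends Partitioner {
      def getPartition(key: Any): Int = key match {
        case null => 0
        case s: String if s.nonEmpty => s.charAt(0).toInt % numPartitions
        case other => ((other.hashCode % numPartitions) + numPartitions) % numPartitions
      }
      override def equals(other: Any): Boolean = other match {
        case p: FirstCharPartitioner => p.numPartitions == numPartitions
        case _ => false
      }
    }

    val pairs = sc.parallelize(Seq("apple" -> 1, "avocado" -> 2, "banana" -> 3))
    val byFirstChar = pairs.partitionBy(new FirstCharPartitioner(8))   // custom placement
    val hashed      = pairs.partitionBy(new HashPartitioner(8))        // hash-based placement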
diff --git a/core/src/main/scala/org/apache/spark/RDD.scala b/core/src/main/scala/org/apache/spark/RDD.scala
new file mode 100644
index 0000000000..0d1f07f76c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/RDD.scala
@@ -0,0 +1,957 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.Random
+
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.TextOutputFormat
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+import org.apache.spark.Partitioner._
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.partial.BoundedDouble
+import org.apache.spark.partial.CountEvaluator
+import org.apache.spark.partial.GroupedCountEvaluator
+import org.apache.spark.partial.PartialResult
+import org.apache.spark.rdd.CoalescedRDD
+import org.apache.spark.rdd.CartesianRDD
+import org.apache.spark.rdd.FilteredRDD
+import org.apache.spark.rdd.FlatMappedRDD
+import org.apache.spark.rdd.GlommedRDD
+import org.apache.spark.rdd.MappedRDD
+import org.apache.spark.rdd.MapPartitionsRDD
+import org.apache.spark.rdd.MapPartitionsWithIndexRDD
+import org.apache.spark.rdd.PipedRDD
+import org.apache.spark.rdd.SampledRDD
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.rdd.ZippedRDD
+import org.apache.spark.rdd.ZippedPartitionsRDD2
+import org.apache.spark.rdd.ZippedPartitionsRDD3
+import org.apache.spark.rdd.ZippedPartitionsRDD4
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.BoundedPriorityQueue
+
+import SparkContext._
+
+/**
+ * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
+ * partitioned collection of elements that can be operated on in parallel. This class contains the
+ * basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
+ * [[org.apache.spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
+ * as `groupByKey` and `join`; [[org.apache.spark.DoubleRDDFunctions]] contains operations available only on
+ * RDDs of Doubles; and [[org.apache.spark.SequenceFileRDDFunctions]] contains operations available on RDDs
+ * that can be saved as SequenceFiles. These operations are automatically available on any RDD of
+ * the right type (e.g. RDD[(Int, Int)]) through implicit conversions when you
+ * `import org.apache.spark.SparkContext._`.
+ *
+ * Internally, each RDD is characterized by five main properties:
+ *
+ * - A list of partitions
+ * - A function for computing each split
+ * - A list of dependencies on other RDDs
+ * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
+ * - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
+ * an HDFS file)
+ *
+ * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
+ * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
+ * reading data from a new storage system) by overriding these functions. Please refer to the
+ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
+ * on RDD internals.
+ */
+abstract class RDD[T: ClassManifest](
+ @transient private var sc: SparkContext,
+ @transient private var deps: Seq[Dependency[_]]
+ ) extends Serializable with Logging {
+
+ /** Construct an RDD with just a one-to-one dependency on one parent */
+ def this(@transient oneParent: RDD[_]) =
+ this(oneParent.context , List(new OneToOneDependency(oneParent)))
+
+ // =======================================================================
+ // Methods that should be implemented by subclasses of RDD
+ // =======================================================================
+
+ /** Implemented by subclasses to compute a given partition. */
+ def compute(split: Partition, context: TaskContext): Iterator[T]
+
+ /**
+ * Implemented by subclasses to return the set of partitions in this RDD. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getPartitions: Array[Partition]
+
+ /**
+ * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
+ * be called once, so it is safe to implement a time-consuming computation in it.
+ */
+ protected def getDependencies: Seq[Dependency[_]] = deps
+
+ /** Optionally overridden by subclasses to specify placement preferences. */
+ protected def getPreferredLocations(split: Partition): Seq[String] = Nil
+
+ /** Optionally overridden by subclasses to specify how they are partitioned. */
+ val partitioner: Option[Partitioner] = None
+
+ // =======================================================================
+ // Methods and fields available on all RDDs
+ // =======================================================================
+
+ /** The SparkContext that created this RDD. */
+ def sparkContext: SparkContext = sc
+
+ /** A unique ID for this RDD (within its SparkContext). */
+ val id: Int = sc.newRddId()
+
+ /** A friendly name for this RDD */
+ var name: String = null
+
+ /** Assign a name to this RDD */
+ def setName(_name: String) = {
+ name = _name
+ this
+ }
+
+ /** User-defined generator of this RDD. */
+ var generator = Utils.getCallSiteInfo.firstUserClass
+
+ /** Reset generator. */
+ def setGenerator(_generator: String) = {
+ generator = _generator
+ }
+
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet.
+ */
+ def persist(newLevel: StorageLevel): RDD[T] = {
+ // TODO: Handle changes of StorageLevel
+ if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
+ throw new UnsupportedOperationException(
+ "Cannot change storage level of an RDD after it was already assigned a level")
+ }
+ storageLevel = newLevel
+ // Register the RDD with the SparkContext
+ sc.persistentRdds(id) = this
+ this
+ }
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def cache(): RDD[T] = persist()
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ * @return This RDD.
+ */
+ def unpersist(blocking: Boolean = true): RDD[T] = {
+ logInfo("Removing RDD " + id + " from persistence list")
+ sc.env.blockManager.master.removeRdd(id, blocking)
+ sc.persistentRdds.remove(id)
+ storageLevel = StorageLevel.NONE
+ this
+ }
+
+ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
+ def getStorageLevel = storageLevel
+
+ // Our dependencies and partitions will be obtained by calling the subclass's methods below, and will
+ // be overwritten when we're checkpointed
+ private var dependencies_ : Seq[Dependency[_]] = null
+ @transient private var partitions_ : Array[Partition] = null
+
+ /** An Option holding our checkpoint RDD, if we are checkpointed */
+ private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
+
+ /**
+ * Get the list of dependencies of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def dependencies: Seq[Dependency[_]] = {
+ checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
+ if (dependencies_ == null) {
+ dependencies_ = getDependencies
+ }
+ dependencies_
+ }
+ }
+
+ /**
+ * Get the array of partitions of this RDD, taking into account whether the
+ * RDD is checkpointed or not.
+ */
+ final def partitions: Array[Partition] = {
+ checkpointRDD.map(_.partitions).getOrElse {
+ if (partitions_ == null) {
+ partitions_ = getPartitions
+ }
+ partitions_
+ }
+ }
+
+ /**
+ * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * RDD is checkpointed.
+ */
+ final def preferredLocations(split: Partition): Seq[String] = {
+ checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
+ getPreferredLocations(split)
+ }
+ }
+
+ /**
+ * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
+ * This should ''not'' be called by users directly, but is available for implementors of custom
+ * subclasses of RDD.
+ */
+ final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
+ if (storageLevel != StorageLevel.NONE) {
+ SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+ } else {
+ computeOrReadCheckpoint(split, context)
+ }
+ }
+
+ /**
+ * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
+ */
+ private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
+ if (isCheckpointed) {
+ firstParent[T].iterator(split, context)
+ } else {
+ compute(split, context)
+ }
+ }
+
+ // Transformations (return a new RDD)
+
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
+ def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
+
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
+ new FlatMappedRDD(this, sc.clean(f))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
+ def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numPartitions: Int): RDD[T] =
+ map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1)
+
+ def distinct(): RDD[T] = distinct(partitions.size)
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
+ if (shuffle) {
+ // include a shuffle step so that our upstream tasks are still distributed
+ new CoalescedRDD(
+ new ShuffledRDD[T, Null, (T, Null)](map(x => (x, null)),
+ new HashPartitioner(numPartitions)),
+ numPartitions).keys
+ } else {
+ new CoalescedRDD(this, numPartitions)
+ }
+ }
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
+ new SampledRDD(this, withReplacement, fraction, seed)
+
+ def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+ var fraction = 0.0
+ var total = 0
+ val multiplier = 3.0
+ val initialCount = this.count()
+ var maxSelected = 0
+
+ if (num < 0) {
+ throw new IllegalArgumentException("Negative number of elements requested")
+ }
+
+ if (initialCount > Integer.MAX_VALUE - 1) {
+ maxSelected = Integer.MAX_VALUE - 1
+ } else {
+ maxSelected = initialCount.toInt
+ }
+
+ if (num > initialCount && !withReplacement) {
+ total = maxSelected
+ fraction = multiplier * (maxSelected + 1) / initialCount
+ } else {
+ fraction = multiplier * (num + 1) / initialCount
+ total = num
+ }
+
+ val rand = new Random(seed)
+ var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+
+ // If the first sample didn't turn out large enough, keep trying to take samples;
+ // this shouldn't happen often because we use a big multiplier for the initial size
+ while (samples.length < total) {
+ samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+ }
+
+ Utils.randomizeInPlace(samples, rand).take(total)
+ }
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
+ def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
+ def ++(other: RDD[T]): RDD[T] = this.union(other)
+
+ /**
+ * Return an RDD created by coalescing all elements within each partition into an array.
+ */
+ def glom(): RDD[Array[T]] = new GlommedRDD(this)
+
+ /**
+ * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
+ * elements (a, b) where a is in `this` and b is in `other`.
+ */
+ def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
+
+ /**
+ * Return an RDD of grouped items.
+ */
+ def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+ groupBy[K](f, defaultPartitioner(this))
+
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
+ def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ groupBy(f, new HashPartitioner(numPartitions))
+
+ /**
+ * Return an RDD of grouped items.
+ */
+ def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ val cleanF = sc.clean(f)
+ this.map(t => (cleanF(t), t)).groupByKey(p)
+ }
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
+ def pipe(command: String): RDD[String] = new PipedRDD(this, command)
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
+ def pipe(command: String, env: Map[String, String]): RDD[String] =
+ new PipedRDD(this, command, env)
+
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ * The print behavior can be customized by providing two functions.
+ *
+ * @param command command to run in forked process.
+ * @param env environment variables to set.
+ * @param printPipeContext Before piping elements, this function is called as an opportunity
+ * to pipe context data. Print line function (like out.println) will be
+ * passed as printPipeContext's parameter.
+ * @param printRDDElement Use this function to customize how to pipe elements. This function
+ * will be called with each RDD element as the 1st parameter, and the
+ * print line function (like out.println()) as the 2nd parameter.
+ * For example, to pipe the RDD data of groupBy() in a streaming way,
+ * instead of constructing a huge String to concatenate all the elements:
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
+ * @return the result RDD
+ */
+ def pipe(
+ command: Seq[String],
+ env: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ new PipedRDD(this, command, env,
+ if (printPipeContext ne null) sc.clean(printPipeContext) else null,
+ if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
+ def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ def mapPartitionsWithIndex[U: ClassManifest](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
+ * of the original partition.
+ */
+ @deprecated("use mapPartitionsWithIndex", "0.7.0")
+ def mapPartitionsWithSplit[U: ClassManifest](
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] =
+ new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+
+ /**
+ * Maps f over this RDD, where f takes an additional parameter of type A. This
+ * additional parameter is produced by constructA, which is called in each
+ * partition with the index of that partition.
+ */
+ def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f:(T, A) => U): RDD[U] = {
+ def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(index)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }
+
+ /**
+ * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
+ * additional parameter is produced by constructA, which is called in each
+ * partition with the index of that partition.
+ */
+ def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f:(T, A) => Seq[U]): RDD[U] = {
+ def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(index)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }
+
+ /**
+ * Applies f to each element of this RDD, where f takes an additional parameter of type A.
+ * This additional parameter is produced by constructA, which is called in each
+ * partition with the index of that partition.
+ */
+ def foreachWith[A: ClassManifest](constructA: Int => A)
+ (f:(T, A) => Unit) {
+ def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(index)
+ iter.map(t => {f(t, a); t})
+ }
+ (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ }
+
+ /**
+ * Filters this RDD with p, where p takes an additional parameter of type A. This
+ * additional parameter is produced by constructA, which is called in each
+ * partition with the index of that partition.
+ */
+ def filterWith[A: ClassManifest](constructA: Int => A)
+ (p:(T, A) => Boolean): RDD[T] = {
+ def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(index)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ }
+
+ /**
+ * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+ * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+ * partitions* and the *same number of elements in each partition* (e.g. one was made through
+ * a map on the other).
+ */
+ def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[B: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B])
+ (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B], rdd3: RDD[C])
+ (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
+ (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+
+
+ // Actions (launch a job to return a value to the user program)
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreach(f: T => Unit) {
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ }
+
+ /**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartition(f: Iterator[T] => Unit) {
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ }
+
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
+ def collect(): Array[T] = {
+ val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
+ Array.concat(results: _*)
+ }
+
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
+ def toArray(): Array[T] = collect()
+
+ /**
+ * Return an RDD that contains all matching values by applying `f`.
+ */
+ def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ filter(f.isDefinedAt).map(f)
+ }
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses this RDD's partitioner and partition size, because even if `other` is huge, the
+ * resulting RDD will be no larger than this one.
+ */
+ def subtract(other: RDD[T]): RDD[T] =
+ subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: RDD[T], numPartitions: Int): RDD[T] =
+ subtract(other, new HashPartitioner(numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+ if (partitioner == Some(p)) {
+ // Our partitioner knows how to handle T (which, since we have a partitioner, is
+ // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
+ val p2 = new Partitioner() {
+ override def numPartitions = p.numPartitions
+ override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+ }
+ // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
+ // anyway, and when calling .keys, will not have a partitioner set, even though
+ // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
+ // partitioned by the right/real keys (e.g. p).
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+ } else {
+ this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+ }
+ }
+
+ /**
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
+ */
+ def reduce(f: (T, T) => T): T = {
+ val cleanF = sc.clean(f)
+ val reducePartition: Iterator[T] => Option[T] = iter => {
+ if (iter.hasNext) {
+ Some(iter.reduceLeft(cleanF))
+ } else {
+ None
+ }
+ }
+ var jobResult: Option[T] = None
+ val mergeResult = (index: Int, taskResult: Option[T]) => {
+ if (taskResult != None) {
+ jobResult = jobResult match {
+ case Some(value) => Some(f(value, taskResult.get))
+ case None => taskResult
+ }
+ }
+ }
+ sc.runJob(this, reducePartition, mergeResult)
+ // Get the final result out of our Option, or throw an exception if the RDD was empty
+ jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
+ }
+
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using a
+ * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
+ * modify t1 and return it as its result value to avoid object allocation; however, it should not
+ * modify t2.
+ */
+ def fold(zeroValue: T)(op: (T, T) => T): T = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+ val cleanOp = sc.clean(op)
+ val foldPartition = (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)
+ val mergeResult = (index: Int, taskResult: T) => jobResult = op(jobResult, taskResult)
+ sc.runJob(this, foldPartition, mergeResult)
+ jobResult
+ }
+
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using
+ * given combine functions and a neutral "zero value". This function can return a different result
+ * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
+ * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U to avoid memory
+ * allocation.
+ */
+ def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ // Clone the zero value since we will also be serializing it as part of tasks
+ var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
+ val cleanSeqOp = sc.clean(seqOp)
+ val cleanCombOp = sc.clean(combOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
+ sc.runJob(this, aggregatePartition, mergeResult)
+ jobResult
+ }
+
+ /**
+ * Return the number of elements in the RDD.
+ */
+ def count(): Long = {
+ sc.runJob(this, (iter: Iterator[T]) => {
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ }).sum
+ }
+
+ /**
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
+ */
+ def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
+ val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ }
+ val evaluator = new CountEvaluator(partitions.size, confidence)
+ sc.runApproximateJob(this, countElements, evaluator, timeout)
+ }
+
+ /**
+ * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
+ * combine step happens locally on the master, equivalent to running a single reduce task.
+ */
+ def countByValue(): Map[T, Long] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValue() does not support arrays")
+ }
+ // TODO: This should perhaps be distributed by default.
+ def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ Iterator(map)
+ }
+ def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = {
+ val iter = m2.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue)
+ }
+ return m1
+ }
+ val myResult = mapPartitions(countPartition).reduce(mergeMaps)
+ myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map
+ }
+
+ /**
+ * (Experimental) Approximate version of countByValue().
+ */
+ def countByValueApprox(
+ timeout: Long,
+ confidence: Double = 0.95
+ ): PartialResult[Map[T, BoundedDouble]] = {
+ if (elementClassManifest.erasure.isArray) {
+ throw new SparkException("countByValueApprox() does not support arrays")
+ }
+ val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
+ val map = new OLMap[T]
+ while (iter.hasNext) {
+ val v = iter.next()
+ map.put(v, map.getLong(v) + 1L)
+ }
+ map
+ }
+ val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence)
+ sc.runApproximateJob(this, countPartition, evaluator, timeout)
+ }
+
+ /**
+ * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
+ * it will be slow if a lot of partitions are required. In that case, use collect() to get the
+ * whole RDD instead.
+ */
+ def take(num: Int): Array[T] = {
+ if (num == 0) {
+ return new Array[T](0)
+ }
+ val buf = new ArrayBuffer[T]
+ var p = 0
+ while (buf.size < num && p < partitions.size) {
+ val left = num - buf.size
+ val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
+ buf ++= res(0)
+ if (buf.size == num)
+ return buf.toArray
+ p += 1
+ }
+ return buf.toArray
+ }
+
+ /**
+ * Return the first element in this RDD.
+ */
+ def first(): T = take(1) match {
+ case Array(t) => t
+ case _ => throw new UnsupportedOperationException("empty collection")
+ }
+
+ /**
+ * Returns the top K elements from this RDD as defined by
+ * the specified implicit Ordering[T].
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
+ mapPartitions { items =>
+ val queue = new BoundedPriorityQueue[T](num)
+ queue ++= items
+ Iterator.single(queue)
+ }.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray.sorted(ord.reverse)
+ }
+
+ /**
+ * Returns the first K elements from this RDD as defined by
+ * the specified implicit Ordering[T] and maintains the
+ * ordering.
+ * @param num the number of top elements to return
+ * @param ord the implicit ordering for T
+ * @return an array of top elements
+ */
+ def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = top(num)(ord.reverse)
+
+ /**
+ * Save this RDD as a text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String) {
+ this.map(x => (NullWritable.get(), new Text(x.toString)))
+ .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
+ }
+
+ /**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
+ this.map(x => (NullWritable.get(), new Text(x.toString)))
+ .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
+ }
+
+ /**
+ * Save this RDD as a SequenceFile of serialized objects.
+ */
+ def saveAsObjectFile(path: String) {
+ this.mapPartitions(iter => iter.grouped(10).map(_.toArray))
+ .map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
+ .saveAsSequenceFile(path)
+ }
+
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: T => K): RDD[(K, T)] = {
+ map(x => (f(x), x))
+ }
+
+ /** A private method for tests, to look at the contents of each partition */
+ private[spark] def collectPartitions(): Array[Array[T]] = {
+ sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
+ }
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() {
+ if (context.checkpointDir.isEmpty) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ } else if (checkpointData.isEmpty) {
+ checkpointData = Some(new RDDCheckpointData(this))
+ checkpointData.get.markForCheckpoint()
+ }
+ }
+
+ /**
+ * Return whether this RDD has been checkpointed or not
+ */
+ def isCheckpointed: Boolean = {
+ checkpointData.map(_.isCheckpointed).getOrElse(false)
+ }
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed
+ */
+ def getCheckpointFile: Option[String] = {
+ checkpointData.flatMap(_.getCheckpointFile)
+ }
+
+ // =======================================================================
+ // Other internal methods and fields
+ // =======================================================================
+
+ private var storageLevel: StorageLevel = StorageLevel.NONE
+
+ /** Record user function generating this RDD. */
+ private[spark] val origin = Utils.formatSparkCallSite
+
+ private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+
+ private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+
+ /** Returns the first parent RDD */
+ protected[spark] def firstParent[U: ClassManifest] = {
+ dependencies.head.rdd.asInstanceOf[RDD[U]]
+ }
+
+ /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
+ def context = sc
+
+ // Avoid handling doCheckpoint multiple times to prevent excessive recursion
+ private var doCheckpointCalled = false
+
+ /**
+ * Performs the checkpointing of this RDD by saving its data. It is called by the DAGScheduler
+ * after a job using this RDD has completed (therefore the RDD has been materialized and
+ * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
+ */
+ private[spark] def doCheckpoint() {
+ if (!doCheckpointCalled) {
+ doCheckpointCalled = true
+ if (checkpointData.isDefined) {
+ checkpointData.get.doCheckpoint()
+ } else {
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
+ }
+ }
+
+ /**
+ * Changes the dependencies of this RDD from its original parents to the new RDD (`checkpointRDD`)
+ * created from the checkpoint file, and forgets its old dependencies and partitions.
+ */
+ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
+ clearDependencies()
+ partitions_ = null
+ deps = null // Forget the constructor argument for dependencies too
+ }
+
+ /**
+ * Clears the dependencies of this RDD. This method must ensure that all references
+ * to the original parent RDDs are removed to enable the parent RDDs to be garbage
+ * collected. Subclasses of RDD may override this method for implementing their own cleaning
+ * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
+ */
+ protected def clearDependencies() {
+ dependencies_ = null
+ }
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString: String = {
+ def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = {
+ Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++
+ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " "))
+ }
+ debugString(this).mkString("\n")
+ }
+
+ override def toString: String = "%s%s[%d] at %s".format(
+ Option(name).map(_ + " ").getOrElse(""),
+ getClass.getSimpleName,
+ id,
+ origin)
+
+ def toJavaRDD() : JavaRDD[T] = {
+ new JavaRDD(this)(elementClassManifest)
+ }
+
+}
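A minimal end-to-end sketch of the RDD API above, assuming a local deployment; the master string, app name, and data are illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._
    import org.apache.spark.storage.StorageLevel

    val sc = new SparkContext("local", "rdd-example")
    val nums = sc.parallelize(1 to 1000, 4)

    // Transformations are lazy; persist keeps the filtered data cached after the first job.
    val evens = nums.filter(_ % 2 == 0).persist(StorageLevel.MEMORY_ONLY)

    // Actions launch jobs.
    val sum      = evens.reduce(_ + _)
    val top5     = evens.top(5)        // uses the implicit Ordering[Int]
    val firstTen = evens.take(10)      // scans partitions one by one

    println(sum + " / " + top5.toSeq + " / " + firstTen.toSeq)
    sc.stop()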
diff --git a/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala
new file mode 100644
index 0000000000..0334de6924
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/RDDCheckpointData.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
+import rdd.{CheckpointRDD, CoalescedRDD}
+import scheduler.{ResultTask, ShuffleMapTask}
+
+/**
+ * Enumeration to manage state transitions of an RDD through checkpointing
+ * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ */
+private[spark] object CheckpointState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+}
+
+/**
+ * This class contains all the information related to RDD checkpointing. Each instance of this class
+ * is associated with an RDD. It manages the process of checkpointing the associated RDD, and it
+ * also manages the post-checkpoint state by providing the updated partitions, iterator and
+ * preferred locations of the checkpointed RDD.
+ */
+private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+ extends Logging with Serializable {
+
+ import CheckpointState._
+
+ // The checkpoint state of the associated RDD.
+ var cpState = Initialized
+
+ // The file to which the associated RDD has been checkpointed
+ @transient var cpFile: Option[String] = None
+
+ // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
+ var cpRDD: Option[RDD[T]] = None
+
+ // Mark the RDD for checkpointing
+ def markForCheckpoint() {
+ RDDCheckpointData.synchronized {
+ if (cpState == Initialized) cpState = MarkedForCheckpoint
+ }
+ }
+
+ // Is the RDD already checkpointed
+ def isCheckpointed: Boolean = {
+ RDDCheckpointData.synchronized { cpState == Checkpointed }
+ }
+
+ // Get the file to which this RDD was checkpointed, as an Option
+ def getCheckpointFile: Option[String] = {
+ RDDCheckpointData.synchronized { cpFile }
+ }
+
+ // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
+ def doCheckpoint() {
+ // If it is marked for checkpointing AND checkpointing is not already in progress,
+ // then set it to be in progress, else return
+ RDDCheckpointData.synchronized {
+ if (cpState == MarkedForCheckpoint) {
+ cpState = CheckpointingInProgress
+ } else {
+ return
+ }
+ }
+
+ // Create the output path for the checkpoint
+ val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+ val fs = path.getFileSystem(new Configuration())
+ if (!fs.mkdirs(path)) {
+ throw new SparkException("Failed to create checkpoint path " + path)
+ }
+
+ // Save to file, and reload it as an RDD
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
+ val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
+
+ // Change the dependencies and partitions of the RDD
+ RDDCheckpointData.synchronized {
+ cpFile = Some(path.toString)
+ cpRDD = Some(newRDD)
+ rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
+ cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
+ logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
+ }
+ }
+
+ // Get preferred location of a split after checkpointing
+ def getPreferredLocations(split: Partition): Seq[String] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.preferredLocations(split)
+ }
+ }
+
+ def getPartitions: Array[Partition] = {
+ RDDCheckpointData.synchronized {
+ cpRDD.get.partitions
+ }
+ }
+
+ def checkpointRDD: Option[RDD[T]] = {
+ RDDCheckpointData.synchronized {
+ cpRDD
+ }
+ }
+}
+
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
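A sketch of the user-facing flow that drives RDDCheckpointData, assuming `sc` is in scope and SparkContext.setCheckpointDir is available as in the standard API; the directory is hypothetical.

    sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")

    val data = sc.parallelize(1 to 100).map(_ * 2).cache()
    data.checkpoint()               // marks the RDD; nothing is written yet
    data.count()                    // the first job triggers doCheckpoint() afterwards

    println(data.isCheckpointed)    // true once the state reaches Checkpointed
    println(data.getCheckpointFile) // Some(.../rdd-<id>), matching the path built above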
diff --git a/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala
new file mode 100644
index 0000000000..d58fb4e4bc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SequenceFileRDDFunctions.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.EOFException
+import java.net.URL
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+import java.util.HashSet
+import java.util.Random
+import java.util.Date
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.mapred.TextOutputFormat
+import org.apache.hadoop.mapred.SequenceFileOutputFormat
+import org.apache.hadoop.mapred.OutputCommitter
+import org.apache.hadoop.mapred.FileOutputCommitter
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.Text
+
+import org.apache.spark.SparkContext._
+
+/**
+ * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
+ * through an implicit conversion. Note that this can't be part of PairRDDFunctions because
+ * we need more implicit parameters to convert our keys and values to Writable.
+ *
+ * Users should import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
+ */
+class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
+ self: RDD[(K, V)])
+ extends Logging
+ with Serializable {
+
+ private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
+ val c = {
+ if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
+ classManifest[T].erasure
+ } else {
+ // We get the type of the Writable class by looking at the apply method which converts
+ // from T to Writable. Since we have two apply methods we filter out the one which
+ // is not of the form "java.lang.Object apply(java.lang.Object)"
+ implicitly[T => Writable].getClass.getDeclaredMethods().filter(
+ m => m.getReturnType().toString != "class java.lang.Object" &&
+ m.getName() == "apply")(0).getReturnType
+
+ }
+ // TODO: use something like WritableConverter to avoid reflection
+ }
+ c.asInstanceOf[Class[_ <: Writable]]
+ }
+
+ /**
+ * Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key
+ * and value types. If the key or value are Writable, then we use their classes directly;
+ * otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc,
+ * byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
+ * file system.
+ */
+ def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
+ def anyToWritable[U <% Writable](u: U): Writable = u
+
+ val keyClass = getWritableClass[K]
+ val valueClass = getWritableClass[V]
+ val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass)
+ val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass)
+
+ logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
+ val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
+ val jobConf = new JobConf(self.context.hadoopConfiguration)
+ if (!convertKey && !convertValue) {
+ self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
+ } else if (!convertKey && convertValue) {
+ self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
+ } else if (convertKey && !convertValue) {
+ self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
+ } else if (convertKey && convertValue) {
+ self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
+ path, keyClass, valueClass, format, jobConf, codec)
+ }
+ }
+}
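A usage sketch for saveAsSequenceFile, assuming `sc`, the SparkContext._ implicits, and that SparkContext.sequenceFile is available for reading back as in the standard API; the path and data are illustrative.

    // Int and String are not Writable, so keys and values are converted to
    // IntWritable and Text through the implicit views handled above.
    val pairs = sc.parallelize(Seq(1 -> "one", 2 -> "two", 3 -> "three"))
    pairs.saveAsSequenceFile("/tmp/pairs-seq")

    // Reading the file back as native Scala types.
    val restored = sc.sequenceFile[Int, String]("/tmp/pairs-seq")
    println(restored.collect().toSeq)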
diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
new file mode 100644
index 0000000000..fdd4c24e23
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+
+import org.apache.hadoop.io.ObjectWritable
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.conf.Configuration
+
+class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
+ def value = t
+ override def toString = t.toString
+
+ private def writeObject(out: ObjectOutputStream) {
+ out.defaultWriteObject()
+ new ObjectWritable(t).write(out)
+ }
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ val ow = new ObjectWritable()
+ ow.setConf(new Configuration())
+ ow.readFields(in)
+ t = ow.get().asInstanceOf[T]
+ }
+}
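A sketch of why this wrapper exists: Hadoop's Configuration and JobConf are Writable but not java.io.Serializable, so they are wrapped before being captured in a task closure (as PairRDDFunctions does above via `wrappedConf.value`). The RDD `rdd` and context `sc` are assumed to be in scope.

    import org.apache.hadoop.mapred.JobConf
    import org.apache.spark.SerializableWritable

    val wrappedConf = new SerializableWritable(new JobConf(sc.hadoopConfiguration))

    rdd.foreachPartition { iter =>
      val conf = wrappedConf.value   // deserialized copy of the JobConf on the executor
      // ... open per-partition readers or writers with `conf` here ...
      iter.foreach(_ => ())
    }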
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
new file mode 100644
index 0000000000..307c383a89
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.Serializer
+
+
+private[spark] abstract class ShuffleFetcher {
+
+ /**
+ * Fetch the shuffle outputs for a given ShuffleDependency.
+ * @return An iterator over the elements of the fetched shuffle outputs.
+ */
+ def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
+
+ /** Stop the fetcher */
+ def stop() {}
+}
diff --git a/core/src/main/scala/org/apache/spark/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/SizeEstimator.scala
new file mode 100644
index 0000000000..4bfc837710
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SizeEstimator.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.reflect.Field
+import java.lang.reflect.Modifier
+import java.lang.reflect.{Array => JArray}
+import java.util.IdentityHashMap
+import java.util.concurrent.ConcurrentHashMap
+import java.util.Random
+
+import javax.management.MBeanServer
+import java.lang.management.ManagementFactory
+
+import scala.collection.mutable.ArrayBuffer
+
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet
+
+/**
+ * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in
+ * memory-aware caches.
+ *
+ * Based on the following JavaWorld article:
+ * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
+ */
+private[spark] object SizeEstimator extends Logging {
+
+ // Sizes of primitive types
+ private val BYTE_SIZE = 1
+ private val BOOLEAN_SIZE = 1
+ private val CHAR_SIZE = 2
+ private val SHORT_SIZE = 2
+ private val INT_SIZE = 4
+ private val LONG_SIZE = 8
+ private val FLOAT_SIZE = 4
+ private val DOUBLE_SIZE = 8
+
+ // Alignment boundary for objects
+ // TODO: Is this arch dependent ?
+ private val ALIGN_SIZE = 8
+
+ // A cache of ClassInfo objects for each class
+ private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo]
+
+ // Object and pointer sizes are arch dependent
+ private var is64bit = false
+
+ // Size of an object reference
+ // Based on https://wikis.oracle.com/display/HotSpotInternals/CompressedOops
+ private var isCompressedOops = false
+ private var pointerSize = 4
+
+ // Minimum size of a java.lang.Object
+ private var objectSize = 8
+
+ initialize()
+
+ // Sets object size, pointer size based on architecture and CompressedOops settings
+ // from the JVM.
+ private def initialize() {
+ is64bit = System.getProperty("os.arch").contains("64")
+ isCompressedOops = getIsCompressedOops
+
+ objectSize = if (!is64bit) 8 else {
+ if(!isCompressedOops) {
+ 16
+ } else {
+ 12
+ }
+ }
+ pointerSize = if (is64bit && !isCompressedOops) 8 else 4
+ classInfos.clear()
+ classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil))
+ }
+
+ private def getIsCompressedOops : Boolean = {
+ if (System.getProperty("spark.test.useCompressedOops") != null) {
+ return System.getProperty("spark.test.useCompressedOops").toBoolean
+ }
+
+ try {
+ val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
+ val server = ManagementFactory.getPlatformMBeanServer()
+
+ // NOTE: This should throw an exception in non-Sun JVMs
+ val hotSpotMBeanClass = Class.forName("com.sun.management.HotSpotDiagnosticMXBean")
+ val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption",
+ Class.forName("java.lang.String"))
+
+ val bean = ManagementFactory.newPlatformMXBeanProxy(server,
+ hotSpotMBeanName, hotSpotMBeanClass)
+ // TODO: We could use reflection on the VMOption returned ?
+ return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true")
+ } catch {
+ case e: Exception => {
+ // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
+ val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
+ val guessInWords = if (guess) "yes" else "not"
+ logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
+ return guess
+ }
+ }
+ }
+
+ /**
+ * The state of an ongoing size estimation. Contains a stack of objects to visit as well as an
+ * IdentityHashMap of visited objects, and provides utility methods for enqueueing new objects
+ * to visit.
+ */
+ private class SearchState(val visited: IdentityHashMap[AnyRef, AnyRef]) {
+ val stack = new ArrayBuffer[AnyRef]
+ var size = 0L
+
+ def enqueue(obj: AnyRef) {
+ if (obj != null && !visited.containsKey(obj)) {
+ visited.put(obj, null)
+ stack += obj
+ }
+ }
+
+ def isFinished(): Boolean = stack.isEmpty
+
+ def dequeue(): AnyRef = {
+ val elem = stack.last
+ stack.trimEnd(1)
+ return elem
+ }
+ }
+
+ /**
+ * Cached information about each class. We remember two things: the "shell size" of the class
+ * (size of all non-static fields plus the java.lang.Object size), and any fields that are
+ * pointers to objects.
+ */
+ private class ClassInfo(
+ val shellSize: Long,
+ val pointerFields: List[Field]) {}
+
+ def estimate(obj: AnyRef): Long = estimate(obj, new IdentityHashMap[AnyRef, AnyRef])
+
+ private def estimate(obj: AnyRef, visited: IdentityHashMap[AnyRef, AnyRef]): Long = {
+ val state = new SearchState(visited)
+ state.enqueue(obj)
+ while (!state.isFinished) {
+ visitSingleObject(state.dequeue(), state)
+ }
+ return state.size
+ }
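+
+ // Illustrative only (not part of the original file): a hypothetical helper showing how a
+ // memory-aware cache might query this estimator. The exact value returned depends on the
+ // JVM, architecture and CompressedOops settings detected in initialize().
+ private def exampleEstimate(): Long = {
+   // Roughly 8000 bytes of long data plus the array header, rounded up to the alignment size.
+   estimate(new Array[Long](1000))
+ }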
+
+ private def visitSingleObject(obj: AnyRef, state: SearchState) {
+ val cls = obj.getClass
+ if (cls.isArray) {
+ visitArray(obj, cls, state)
+ } else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
+ // Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
+ // the size estimator since it references the whole REPL. Do nothing in this case. In
+ // general all ClassLoaders and Classes will be shared between objects anyway.
+ } else {
+ val classInfo = getClassInfo(cls)
+ state.size += classInfo.shellSize
+ for (field <- classInfo.pointerFields) {
+ state.enqueue(field.get(obj))
+ }
+ }
+ }
+
+ // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling.
+ private val ARRAY_SIZE_FOR_SAMPLING = 200
+ private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
+
+ private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
+ val length = JArray.getLength(array)
+ val elementClass = cls.getComponentType
+
+ // Arrays have an object header and a length field, which is an integer
+ var arrSize: Long = alignSize(objectSize + INT_SIZE)
+
+ if (elementClass.isPrimitive) {
+ arrSize += alignSize(length * primitiveSize(elementClass))
+ state.size += arrSize
+ } else {
+ arrSize += alignSize(length * pointerSize)
+ state.size += arrSize
+
+ if (length <= ARRAY_SIZE_FOR_SAMPLING) {
+ for (i <- 0 until length) {
+ state.enqueue(JArray.get(array, i))
+ }
+ } else {
+ // Estimate the size of a large array by sampling elements without replacement.
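+ // For example (illustrative numbers only): if the 100 sampled elements of a 10,000-element
+ // array sum to 4,000 bytes, the element payload is estimated as 4,000 * (10,000 / 100)
+ // = 400,000 bytes, added on top of the header and per-element pointer sizes accounted for above.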
+ var size = 0.0
+ val rand = new Random(42)
+ val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE)
+ for (i <- 0 until ARRAY_SAMPLE_SIZE) {
+ var index = 0
+ do {
+ index = rand.nextInt(length)
+ } while (drawn.contains(index))
+ drawn.add(index)
+ val elem = JArray.get(array, index)
+ size += SizeEstimator.estimate(elem, state.visited)
+ }
+ state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
+ }
+ }
+ }
+
+ private def primitiveSize(cls: Class[_]): Long = {
+ if (cls == classOf[Byte])
+ BYTE_SIZE
+ else if (cls == classOf[Boolean])
+ BOOLEAN_SIZE
+ else if (cls == classOf[Char])
+ CHAR_SIZE
+ else if (cls == classOf[Short])
+ SHORT_SIZE
+ else if (cls == classOf[Int])
+ INT_SIZE
+ else if (cls == classOf[Long])
+ LONG_SIZE
+ else if (cls == classOf[Float])
+ FLOAT_SIZE
+ else if (cls == classOf[Double])
+ DOUBLE_SIZE
+ else throw new IllegalArgumentException(
+ "Non-primitive class " + cls + " passed to primitiveSize()")
+ }
+
+ /**
+ * Get or compute the ClassInfo for a given class.
+ */
+ private def getClassInfo(cls: Class[_]): ClassInfo = {
+ // Check whether we've already cached a ClassInfo for this class
+ val info = classInfos.get(cls)
+ if (info != null) {
+ return info
+ }
+
+ val parent = getClassInfo(cls.getSuperclass)
+ var shellSize = parent.shellSize
+ var pointerFields = parent.pointerFields
+
+ for (field <- cls.getDeclaredFields) {
+ if (!Modifier.isStatic(field.getModifiers)) {
+ val fieldClass = field.getType
+ if (fieldClass.isPrimitive) {
+ shellSize += primitiveSize(fieldClass)
+ } else {
+ field.setAccessible(true) // Enable future get()'s on this field
+ shellSize += pointerSize
+ pointerFields = field :: pointerFields
+ }
+ }
+ }
+
+ shellSize = alignSize(shellSize)
+
+ // Create and cache a new ClassInfo
+ val newInfo = new ClassInfo(shellSize, pointerFields)
+ classInfos.put(cls, newInfo)
+ return newInfo
+ }
+
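+ // Rounds a size up to the next multiple of ALIGN_SIZE; for example, with the default
+ // 8-byte alignment, alignSize(14) == 16 and alignSize(16) == 16.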
+ private def alignSize(size: Long): Long = {
+ val rem = size % ALIGN_SIZE
+ return if (rem == 0) size else (size + ALIGN_SIZE - rem)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
new file mode 100644
index 0000000000..1207b242bc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -0,0 +1,995 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+import java.net.URI
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.Map
+import scala.collection.generic.Growable
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.util.DynamicVariable
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.ArrayWritable
+import org.apache.hadoop.io.BooleanWritable
+import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.DoubleWritable
+import org.apache.hadoop.io.FloatWritable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
+import org.apache.mesos.MesosNativeLibrary
+
+import org.apache.spark.deploy.LocalSparkCluster
+import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
+import org.apache.spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD,
+ OrderedRDDFunctions}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
+ ClusterScheduler, Schedulable, SchedulingMode}
+import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo, BlockManagerSource}
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+import scala.Some
+import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.storage.RDDInfo
+import org.apache.spark.storage.StorageStatus
+
+/**
+ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
+ * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
+ *
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI.
+ * @param sparkHome Location where Spark is installed on cluster nodes.
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param environment Environment variables to set on worker nodes.
+ */
+class SparkContext(
+ val master: String,
+ val appName: String,
+ val sparkHome: String = null,
+ val jars: Seq[String] = Nil,
+ val environment: Map[String, String] = Map(),
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc.) too.
+ // It is typically generated from InputFormatInfo.computePreferredLocations and maps each host to the
+ // set of data-local splits on that host.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
+ extends Logging {
+
+ // Ensure logging is initialized before we spawn any threads
+ initLogging()
+
+ // Set Spark driver host and port system properties
+ if (System.getProperty("spark.driver.host") == null) {
+ System.setProperty("spark.driver.host", Utils.localHostName())
+ }
+ if (System.getProperty("spark.driver.port") == null) {
+ System.setProperty("spark.driver.port", "0")
+ }
+
+ val isLocal = (master == "local" || master.startsWith("local["))
+
+ // Create the Spark execution environment (cache, map output tracker, etc)
+ private[spark] val env = SparkEnv.createFromSystemProperties(
+ "<driver>",
+ System.getProperty("spark.driver.host"),
+ System.getProperty("spark.driver.port").toInt,
+ true,
+ isLocal)
+ SparkEnv.set(env)
+
+ // Used to store a URL for each static file/jar together with the file's local timestamp
+ private[spark] val addedFiles = HashMap[String, Long]()
+ private[spark] val addedJars = HashMap[String, Long]()
+
+ // Keeps track of all persisted RDDs
+ private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
+ private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
+
+ // Initialize the Spark UI
+ private[spark] val ui = new SparkUI(this)
+ ui.bind()
+
+ val startTime = System.currentTimeMillis()
+
+ // Add each JAR given through the constructor
+ if (jars != null) {
+ jars.foreach { addJar(_) }
+ }
+
+ // Environment variables to pass to our executors
+ private[spark] val executorEnvs = HashMap[String, String]()
+ // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
+ for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
+ val value = System.getenv(key)
+ if (value != null) {
+ executorEnvs(key) = value
+ }
+ }
+ // Since memory can be set with a system property too, use that
+ executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
+ if (environment != null) {
+ executorEnvs ++= environment
+ }
+
+ // Create and start the scheduler
+ private var taskScheduler: TaskScheduler = {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ // Regular expression for local[N, maxRetries], used in tests with failing tasks
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+ // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
+ // Regular expression for connecting to Spark deploy clusters
+ val SPARK_REGEX = """(spark://.*)""".r
+ // Regular expression for connecting to a Mesos cluster
+ val MESOS_REGEX = """(mesos://.*)""".r
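+
+ // Examples of master strings matched below (illustrative values only): "local", "local[4]",
+ // "local[4, 2]", "local-cluster[2, 1, 512]", "spark://host:7077", "yarn-standalone" and
+ // "mesos://host:5050".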
+
+ master match {
+ case "local" =>
+ new LocalScheduler(1, 0, this)
+
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt, 0, this)
+
+ case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+ new LocalScheduler(threads.toInt, maxFailures.toInt, this)
+
+ case SPARK_REGEX(sparkUrl) =>
+ val scheduler = new ClusterScheduler(this)
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
+ }
+
+ val scheduler = new ClusterScheduler(this)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ val sparkUrl = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
+ scheduler.initialize(backend)
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
+ localCluster.stop()
+ }
+ scheduler
+
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of the cause, it means we cannot proceed!
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available", th)
+ }
+ }
+ val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
+ case _ =>
+ if (MESOS_REGEX.findFirstIn(master).isEmpty) {
+ logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
+ }
+ MesosNativeLibrary.load()
+ val scheduler = new ClusterScheduler(this)
+ val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
+ val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
+ val backend = if (coarseGrained) {
+ new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ } else {
+ new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ }
+ scheduler.initialize(backend)
+ scheduler
+ }
+ }
+ taskScheduler.start()
+
+ @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
+
+ ui.start()
+
+ /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ val hadoopConfiguration = {
+ val env = SparkEnv.get
+ val conf = env.hadoop.newConfiguration()
+ // Explicitly check for S3 environment variables
+ if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
+ conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
+ conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
+ }
+ // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
+ for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536")
+ conf.set("io.file.buffer.size", bufferSize)
+ conf
+ }
+
+ private[spark] var checkpointDir: Option[String] = None
+
+ // Thread Local variable that can be used by users to pass information down the stack
+ private val localProperties = new DynamicVariable[Properties](null)
+
+ def initLocalProperties() {
+ localProperties.value = new Properties()
+ }
+
+ def setLocalProperty(key: String, value: String) {
+ if (localProperties.value == null) {
+ localProperties.value = new Properties()
+ }
+ if (value == null) {
+ localProperties.value.remove(key)
+ } else {
+ localProperties.value.setProperty(key, value)
+ }
+ }
+
+ /** Set a human readable description of the current job. */
+ def setJobDescription(value: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ }
+
+ // Post init
+ taskScheduler.postStartHook()
+
+ val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler)
+ val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager)
+
+ def initDriverMetrics() {
+ SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
+ SparkEnv.get.metricsSystem.registerSource(blockManagerSource)
+ }
+
+ initDriverMetrics()
+
+ // Methods for creating RDDs
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
+ }
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ parallelize(seq, numSlices)
+ }
+
+ /** Distribute a local Scala collection to form an RDD, with one or more
+ * location preferences (hostnames of Spark nodes) for each object.
+ * Create a new partition for each collection item. */
+ def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
+ new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
+ }
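+
+ // Illustrative only: makeRDD(Seq(("a", Seq("host1")), ("b", Seq("host2")))) creates one
+ // partition per element and prefers to schedule each partition on the listed host.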
+
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
+ def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
+ hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits)
+ .map(pair => pair._2.toString)
+ }
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
+ * etc).
+ */
+ def hadoopRDD[K, V](
+ conf: JobConf,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int = defaultMinSplits
+ ): RDD[(K, V)] = {
+ new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+ }
+
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V](
+ path: String,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int = defaultMinSplits
+ ) : RDD[(K, V)] = {
+ val conf = new JobConf(hadoopConfiguration)
+ FileInputFormat.setInputPaths(conf, path)
+ new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
+ }
+
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
+ * }}}
+ */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ : RDD[(K, V)] = {
+ hadoopFile(path,
+ fm.erasure.asInstanceOf[Class[F]],
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]],
+ minSplits)
+ }
+
+ /**
+ * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * values and the InputFormat so that users don't need to pass them directly. Instead, callers
+ * can just write, for example,
+ * {{{
+ * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
+ * }}}
+ */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
+ hadoopFile[K, V, F](path, defaultMinSplits)
+
+ /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+ newAPIHadoopFile(
+ path,
+ fm.erasure.asInstanceOf[Class[F]],
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]])
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
+ path: String,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V],
+ conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ val job = new NewHadoopJob(conf)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updatedConf = job.getConfiguration
+ new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
+ conf: Configuration = hadoopConfiguration,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V]): RDD[(K, V)] = {
+ new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+ }
+
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
+ def sequenceFile[K, V](path: String,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): RDD[(K, V)] = {
+ val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
+ hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
+ }
+
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
+ sequenceFile(path, keyClass, valueClass, defaultMinSplits)
+
+ /**
+ * Version of sequenceFile() for types implicitly convertible to Writables through a
+ * WritableConverter. For example, to access a SequenceFile where the keys are Text and the
+ * values are IntWritable, you could simply write
+ * {{{
+ * sparkContext.sequenceFile[String, Int](path, ...)
+ * }}}
+ *
+ * WritableConverters are provided in a somewhat strange way (by an implicit function) to support
+ * both subclasses of Writable and types for which we define a converter (e.g. Int to
+ * IntWritable). The most natural thing would've been to have implicit objects for the
+ * converters, but then we couldn't have an object for every subclass of Writable (you can't
+ * have a parameterized singleton object). We use functions instead to create a new converter
+ * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
+ * allow it to figure out the Writable class to use in the subclass case.
+ */
+ def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V],
+ kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
+ : RDD[(K, V)] = {
+ val kc = kcf()
+ val vc = vcf()
+ val format = classOf[SequenceFileInputFormat[Writable, Writable]]
+ val writables = hadoopFile(path, format,
+ kc.writableClass(km).asInstanceOf[Class[Writable]],
+ vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits)
+ writables.map{case (k,v) => (kc.convert(k), vc.convert(v))}
+ }
+
+ /**
+ * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
+ * BytesWritable values that contain a serialized partition. This is still an experimental storage
+ * format and may not be supported exactly as is in future Spark releases. It will also be pretty
+ * slow if you use the default serializer (Java serialization), though the nice thing about it is
+ * that there's very little effort required to save arbitrary objects.
+ */
+ def objectFile[T: ClassManifest](
+ path: String,
+ minSplits: Int = defaultMinSplits
+ ): RDD[T] = {
+ sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
+ .flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
+ }
+
+
+ protected[spark] def checkpointFile[T: ClassManifest](
+ path: String
+ ): RDD[T] = {
+ new CheckpointRDD[T](this, path)
+ }
+
+ /** Build the union of a list of RDDs. */
+ def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+
+ /** Build the union of a list of RDDs passed as variable-length arguments. */
+ def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
+ new UnionRDD(this, Seq(first) ++ rest)
+
+ // Methods for creating shared variables
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values
+ * to using the `+=` method. Only the driver can access the accumulator's `value`.
+ */
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
+ new Accumulator(initialValue, param)
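+
+ // Illustrative only (not part of the original file): typical accumulator usage, e.g.
+ //   val acc = sc.accumulator(0)
+ //   sc.parallelize(1 to 10).foreach(x => acc += x)
+ //   acc.value  // 55, readable only on the driver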
+
+ /**
+ * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+ * Only the driver can access the accumulable's `value`.
+ * @tparam T accumulator type
+ * @tparam R type that can be added to the accumulator
+ */
+ def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
+ new Accumulable(initialValue, param)
+
+ /**
+ * Create an accumulator from a "mutable collection" type.
+ *
+ * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+ * standard mutable collections. So you can use this with mutable Map, Set, etc.
+ */
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+ val param = new GrowableAccumulableParam[R,T]
+ new Accumulable(initialValue, param)
+ }
+
+ /**
+ * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for
+ * reading it in distributed functions. The variable will be sent to each node only once.
+ */
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
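+
+ // Illustrative only: val lookup = sc.broadcast(Map(1 -> "one")); tasks then read
+ // lookup.value instead of capturing the map in every closure they ship.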
+
+ /**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
+ */
+ def addFile(path: String) {
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
+ case _ => path
+ }
+ addedFiles(key) = System.currentTimeMillis
+
+ // Fetch the file locally in case a job is executed locally.
+ // Jobs that run through LocalScheduler will already fetch the required dependencies,
+ // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+ Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
+
+ logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
+ }
+
+ def addSparkListener(listener: SparkListener) {
+ dagScheduler.addSparkListener(listener)
+ }
+
+ /**
+ * Return a map from each slave to both its maximum memory available for caching and the
+ * memory that currently remains available for caching.
+ */
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
+ }
+ }
+
+ /**
+ * Return information about what RDDs are cached, whether they are in memory or on disk, how much space
+ * they take, etc.
+ */
+ def getRDDStorageInfo: Array[RDDInfo] = {
+ StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
+ }
+
+ /**
+ * Returns an immutable map of RDDs that have marked themselves as persistent via a cache() call.
+ * Note that this does not necessarily mean the caching or computation was successful.
+ */
+ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
+
+ def getStageInfo: Map[Stage,StageInfo] = {
+ dagScheduler.stageToInfos
+ }
+
+ /**
+ * Return information about blocks stored in all of the slaves
+ */
+ def getExecutorStorageStatus: Array[StorageStatus] = {
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
+ * Return the pools used by the fair scheduler
+ * TODO(xiajunluan): We should take nested pools into account
+ */
+ def getAllPools: ArrayBuffer[Schedulable] = {
+ taskScheduler.rootPool.schedulableQueue
+ }
+
+ /**
+ * Return the pool associated with the given name, if one exists
+ */
+ def getPoolForName(pool: String): Option[Schedulable] = {
+ taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)
+ }
+
+ /**
+ * Return current scheduling mode
+ */
+ def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ taskScheduler.schedulingMode
+ }
+
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ addedFiles.clear()
+ }
+
+ /**
+ * Gets the locality information associated with a given partition of a particular RDD.
+ * @param rdd the RDD of interest
+ * @param partition the partition whose preferred locations should be looked up
+ * @return list of preferred locations for the partition
+ */
+ private[spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
+ dagScheduler.getPreferredLocs(rdd, partition)
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ if (null == path) {
+ logWarning("null specified as parameter to addJar",
+ new SparkException("null specified as parameter to addJar"))
+ } else {
+ val env = SparkEnv.get
+ val uri = new URI(path)
+ val key = uri.getScheme match {
+ case null | "file" =>
+ if (env.hadoop.isYarnMode()) {
+ logWarning("local jar specified as parameter to addJar under Yarn mode")
+ return
+ }
+ env.httpFileServer.addJar(new File(uri.getPath))
+ case _ => path
+ }
+ addedJars(key) = System.currentTimeMillis
+ logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
+ }
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ addedJars.clear()
+ }
+
+ /** Shut down the SparkContext. */
+ def stop() {
+ ui.stop()
+ // Do this only if the context has not already been stopped (best effort);
+ // this prevents an NPE if stop() is called more than once.
+ val dagSchedulerCopy = dagScheduler
+ dagScheduler = null
+ if (dagSchedulerCopy != null) {
+ metadataCleaner.cancel()
+ dagSchedulerCopy.stop()
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ // Clean up locally linked files
+ clearFiles()
+ clearJars()
+ SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ logInfo("Successfully stopped SparkContext")
+ } else {
+ logInfo("SparkContext already stopped")
+ }
+ }
+
+
+ /**
+ * Get Spark's home location from either a value set through the constructor,
+ * or the spark.home Java property, or the SPARK_HOME environment variable
+ * (in that order of preference). If none of these is set, return None.
+ */
+ private[spark] def getSparkHome(): Option[String] = {
+ if (sparkHome != null) {
+ Some(sparkHome)
+ } else if (System.getProperty("spark.home") != null) {
+ Some(System.getProperty("spark.home"))
+ } else if (System.getenv("SPARK_HOME") != null) {
+ Some(System.getenv("SPARK_HOME"))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Run a function on a given set of partitions in an RDD and pass the results to the given
+ * handler function. This is the main entry point for all actions in Spark. The allowLocal
+ * flag specifies whether the scheduler can run the computation on the driver rather than
+ * shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit) {
+ val callSite = Utils.formatSparkCallSite
+ logInfo("Starting job: " + callSite)
+ val start = System.nanoTime
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
+ logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ rdd.doCheckpoint()
+ result
+ }
+
+ /**
+ * Run a function on a given set of partitions in an RDD and return the results as an array. The
+ * allowLocal flag specifies whether the scheduler can run the computation on the driver rather
+ * than shipping it out to the cluster, for short actions like first().
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ val results = new Array[U](partitions.size)
+ runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
+ results
+ }
+
+ /**
+ * Run a job on a given set of partitions of an RDD, but take a function of type
+ * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ func: Iterator[T] => U,
+ partitions: Seq[Int],
+ allowLocal: Boolean
+ ): Array[U] = {
+ runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and return the results in an array.
+ */
+ def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and return the results in an array.
+ */
+ def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
+ runJob(rdd, func, 0 until rdd.partitions.size, false)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: (TaskContext, Iterator[T]) => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
+ }
+
+ /**
+ * Run a job on all partitions in an RDD and pass the results to a handler function.
+ */
+ def runJob[T, U: ClassManifest](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ resultHandler: (Int, U) => Unit)
+ {
+ val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
+ runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
+ }
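+
+ // Illustrative only (not part of the original file): for example, counting the elements of
+ // each partition with the (TaskContext, Iterator) form,
+ //   sc.runJob(rdd, (ctx: TaskContext, it: Iterator[Int]) => it.size)
+ // returns an Array[Int] with one entry per partition.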
+
+ /**
+ * Run a job that can return approximate results.
+ */
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long): PartialResult[R] = {
+ val callSite = Utils.formatSparkCallSite
+ logInfo("Starting job: " + callSite)
+ val start = System.nanoTime
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
+ logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
+ result
+ }
+
+ /**
+ * Clean a closure to make it ready to be serialized and sent to tasks
+ * (removes unreferenced variables in $outer's, updates REPL variables)
+ */
+ private[spark] def clean[F <: AnyRef](f: F): F = {
+ ClosureCleaner.clean(f)
+ return f
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean = false) {
+ val env = SparkEnv.get
+ val path = new Path(dir)
+ val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ if (!useExisting) {
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ }
+ }
+ checkpointDir = Some(dir)
+ }
+
+ /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
+ def defaultParallelism: Int = taskScheduler.defaultParallelism
+
+ /** Default min number of partitions for Hadoop RDDs when not given by user */
+ def defaultMinSplits: Int = math.min(defaultParallelism, 2)
+
+ private val nextShuffleId = new AtomicInteger(0)
+
+ private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
+
+ private val nextRddId = new AtomicInteger(0)
+
+ /** Register a new RDD, returning its RDD ID */
+ private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+
+ /** Called by MetadataCleaner to clean up the persistentRdds map periodically */
+ private[spark] def cleanup(cleanupTime: Long) {
+ persistentRdds.clearOldValues(cleanupTime)
+ }
+}
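+
+ // Illustrative only (not part of the original file): a minimal sketch of driving the API above
+ // end to end against a local master. The object name, application name and master string are
+ // placeholders, not part of the Spark API.
+ private[spark] object SparkContextUsageExample {
+   def main(args: Array[String]) {
+     val sc = new SparkContext("local[2]", "SparkContextUsageExample")
+     // parallelize() a small collection, transform it, and run an action through the DAGScheduler.
+     val doubled = sc.parallelize(1 to 100, 4).map(_ * 2)
+     println("Sum of doubled values: " + doubled.reduce(_ + _))
+     sc.stop()
+   }
+ }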
+
+/**
+ * The SparkContext object contains a number of implicit conversions and parameters for use with
+ * various Spark features.
+ */
+object SparkContext {
+ val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ def addInPlace(t1: Double, t2: Double): Double = t1 + t2
+ def zero(initialValue: Double) = 0.0
+ }
+
+ implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ def addInPlace(t1: Int, t2: Int): Int = t1 + t2
+ def zero(initialValue: Int) = 0
+ }
+
+ implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ def addInPlace(t1: Long, t2: Long) = t1 + t2
+ def zero(initialValue: Long) = 0L
+ }
+
+ implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ def addInPlace(t1: Float, t2: Float) = t1 + t2
+ def zero(initialValue: Float) = 0f
+ }
+
+ // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
+ implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ new PairRDDFunctions(rdd)
+
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ rdd: RDD[(K, V)]) =
+ new SequenceFileRDDFunctions(rdd)
+
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ rdd: RDD[(K, V)]) =
+ new OrderedRDDFunctions[K, V, (K, V)](rdd)
+
+ implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+
+ implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+
+ // Implicit conversions to common Writable types, for saveAsSequenceFile
+
+ implicit def intToIntWritable(i: Int) = new IntWritable(i)
+
+ implicit def longToLongWritable(l: Long) = new LongWritable(l)
+
+ implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
+
+ implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
+
+ implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
+
+ implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob)
+
+ implicit def stringToText(s: String) = new Text(s)
+
+ private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
+ def anyToWritable[U <% Writable](u: U): Writable = u
+
+ new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
+ arr.map(x => anyToWritable(x)).toArray)
+ }
+
+ // Helper objects for converting common types to Writable
+ private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = {
+ val wClass = classManifest[W].erasure.asInstanceOf[Class[W]]
+ new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
+ }
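+
+ // For example, intWritableConverter() below is simpleWritableConverter[Int, IntWritable](_.get),
+ // which lets sequenceFile[Int, Int](path) read IntWritable-backed files without an explicit converter.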
+
+ implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get)
+
+ implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get)
+
+ implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get)
+
+ implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get)
+
+ implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get)
+
+ implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
+
+ implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
+
+ implicit def writableWritableConverter[T <: Writable]() =
+ new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
+
+ /**
+ * Find the JAR from which a given class was loaded, to make it easy for users to pass
+ * their JARs to SparkContext
+ */
+ def jarOfClass(cls: Class[_]): Seq[String] = {
+ val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
+ if (uri != null) {
+ val uriStr = uri.toString
+ if (uriStr.startsWith("jar:file:")) {
+ // URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar
+ List(uriStr.substring("jar:file:".length, uriStr.indexOf('!')))
+ } else {
+ Nil
+ }
+ } else {
+ Nil
+ }
+ }
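+
+ // Illustrative only: jarOfClass(classOf[org.apache.hadoop.io.Text]) typically returns a
+ // one-element list with the path of the JAR that Text was loaded from, suitable for the
+ // `jars` constructor argument.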
+
+ /** Find the JAR that contains the class of a particular object */
+ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
+
+ /** Get the amount of memory per executor requested through system properties or SPARK_MEM */
+ private[spark] val executorMemoryRequested = {
+ // TODO: Might need to add some extra memory for the non-heap parts of the JVM
+ Option(System.getProperty("spark.executor.memory"))
+ .orElse(Option(System.getenv("SPARK_MEM")))
+ .map(Utils.memoryStringToMb)
+ .getOrElse(512)
+ }
+}
+
+/**
+ * A class encapsulating how to convert some type T to Writable. It stores both the Writable class
+ * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
+ * The getter for the writable class takes a ClassManifest[T] in case this is a generic object
+ * that doesn't know the type of T when it is created. This sounds strange but is necessary to
+ * support converting subclasses of Writable to themselves (writableWritableConverter).
+ */
+private[spark] class WritableConverter[T](
+ val writableClass: ClassManifest[T] => Class[_ <: Writable],
+ val convert: Writable => T)
+ extends Serializable
+
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
new file mode 100644
index 0000000000..6e6fe5df6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import collection.mutable
+import serializer.Serializer
+
+import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
+import akka.remote.RemoteActorRefProvider
+
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
+import org.apache.spark.network.ConnectionManager
+import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.util.AkkaUtils
+import org.apache.spark.api.python.PythonWorkerFactory
+
+
+/**
+ * Holds all the runtime environment objects for a running Spark instance (either master or worker),
+ * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
+ * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
+ * objects needs to have the right SparkEnv set. You can get the current environment with
+ * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
+ */
+class SparkEnv (
+ val executorId: String,
+ val actorSystem: ActorSystem,
+ val serializerManager: SerializerManager,
+ val serializer: Serializer,
+ val closureSerializer: Serializer,
+ val cacheManager: CacheManager,
+ val mapOutputTracker: MapOutputTracker,
+ val shuffleFetcher: ShuffleFetcher,
+ val broadcastManager: BroadcastManager,
+ val blockManager: BlockManager,
+ val connectionManager: ConnectionManager,
+ val httpFileServer: HttpFileServer,
+ val sparkFilesDir: String,
+ val metricsSystem: MetricsSystem) {
+
+ private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
+ val hadoop = {
+ val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if(yarnMode) {
+ try {
+ Class.forName("spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ } catch {
+ case th: Throwable => throw new SparkException("Unable to load YARN support", th)
+ }
+ } else {
+ new SparkHadoopUtil
+ }
+ }
+
+ def stop() {
+ pythonWorkers.foreach { case(key, worker) => worker.stop() }
+ httpFileServer.stop()
+ mapOutputTracker.stop()
+ shuffleFetcher.stop()
+ broadcastManager.stop()
+ blockManager.stop()
+ blockManager.master.stop()
+ metricsSystem.stop()
+ actorSystem.shutdown()
+ // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
+ // down, but let's call it anyway in case it gets fixed in a later release
+ actorSystem.awaitTermination()
+ }
+
+ def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+ }
+ }
+}
+
+object SparkEnv extends Logging {
+ private val env = new ThreadLocal[SparkEnv]
+ @volatile private var lastSetSparkEnv : SparkEnv = _
+
+ def set(e: SparkEnv) {
+ lastSetSparkEnv = e
+ env.set(e)
+ }
+
+ /**
+ * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
+ * previously set in any thread.
+ */
+ def get: SparkEnv = {
+ Option(env.get()).getOrElse(lastSetSparkEnv)
+ }
+
+ /**
+ * Returns the ThreadLocal SparkEnv.
+ */
+ def getThreadLocal : SparkEnv = {
+ env.get()
+ }
+
+ def createFromSystemProperties(
+ executorId: String,
+ hostname: String,
+ port: Int,
+ isDriver: Boolean,
+ isLocal: Boolean): SparkEnv = {
+
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
+
+ // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
+ // figure out which port number Akka actually bound to and set spark.driver.port to it.
+ if (isDriver && port == 0) {
+ System.setProperty("spark.driver.port", boundPort.toString)
+ }
+
+ // Set spark.hostPort only if it has not been set already.
+ if (System.getProperty("spark.hostPort", null) == null) {
+ if (!isDriver) {
+ // Unexpected: non-driver processes should already have spark.hostPort set.
+ Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
+ }
+ Utils.checkHost(hostname)
+ System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+ }
+
+ val classLoader = Thread.currentThread.getContextClassLoader
+
+ // Create an instance of the class named by the given Java system property, or by
+ // defaultClassName if the property is not set, and return it as a T
+ def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
+ val name = System.getProperty(propertyName, defaultClassName)
+ Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
+ }
+
+ val serializerManager = new SerializerManager
+
+ val serializer = serializerManager.setDefault(
+ System.getProperty("spark.serializer", "org.apache.spark.JavaSerializer"))
+
+ val closureSerializer = serializerManager.get(
+ System.getProperty("spark.closure.serializer", "org.apache.spark.JavaSerializer"))
+
+ def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+ if (isDriver) {
+ logInfo("Registering " + name)
+ actorSystem.actorOf(Props(newActor), name = name)
+ } else {
+ val driverHost: String = System.getProperty("spark.driver.host", "localhost")
+ val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+ Utils.checkHost(driverHost, "Expected hostname")
+ val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
+ logInfo("Connecting to " + name + ": " + url)
+ actorSystem.actorFor(url)
+ }
+ }
+
+ val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
+ "BlockManagerMaster",
+ new BlockManagerMasterActor(isLocal)))
+ val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
+
+ val connectionManager = blockManager.connectionManager
+
+ val broadcastManager = new BroadcastManager(isDriver)
+
+ val cacheManager = new CacheManager(blockManager)
+
+ // Have to assign trackerActor after initialization as MapOutputTrackerActor
+ // requires the MapOutputTracker itself
+ val mapOutputTracker = new MapOutputTracker()
+ mapOutputTracker.trackerActor = registerOrLookup(
+ "MapOutputTracker",
+ new MapOutputTrackerActor(mapOutputTracker))
+
+ val shuffleFetcher = instantiateClass[ShuffleFetcher](
+ "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
+
+ val httpFileServer = new HttpFileServer()
+ httpFileServer.initialize()
+ System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
+
+ val metricsSystem = if (isDriver) {
+ MetricsSystem.createMetricsSystem("driver")
+ } else {
+ MetricsSystem.createMetricsSystem("executor")
+ }
+ metricsSystem.start()
+
+ // Set the sparkFiles directory, used when downloading dependencies. In local mode,
+ // this is a temporary directory; in distributed mode, this is the executor's current working
+ // directory.
+ val sparkFilesDir: String = if (isDriver) {
+ Utils.createTempDir().getAbsolutePath
+ } else {
+ "."
+ }
+
+ // Warn about deprecated spark.cache.class property
+ if (System.getProperty("spark.cache.class") != null) {
+ logWarning("The spark.cache.class property is no longer being used! Specify storage " +
+ "levels using the RDD.persist() method instead.")
+ }
+
+ new SparkEnv(
+ executorId,
+ actorSystem,
+ serializerManager,
+ serializer,
+ closureSerializer,
+ cacheManager,
+ mapOutputTracker,
+ shuffleFetcher,
+ broadcastManager,
+ blockManager,
+ connectionManager,
+ httpFileServer,
+ sparkFilesDir,
+ metricsSystem)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
new file mode 100644
index 0000000000..d34e47e8ca
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+class SparkException(message: String, cause: Throwable)
+ extends Exception(message, cause) {
+
+ def this(message: String) = this(message, null)
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkFiles.java b/core/src/main/scala/org/apache/spark/SparkFiles.java
new file mode 100644
index 0000000000..af9cf85e37
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkFiles.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.File;
+
+/**
+ * Resolves paths to files added through `SparkContext.addFile()`.
+ */
+public class SparkFiles {
+
+ private SparkFiles() {}
+
+ /**
+ * Get the absolute path of a file added through `SparkContext.addFile()`.
+ */
+ public static String get(String filename) {
+ return new File(getRootDirectory(), filename).getAbsolutePath();
+ }
+
+ /**
+ * Get the root directory that contains files added through `SparkContext.addFile()`.
+ */
+ public static String getRootDirectory() {
+ return SparkEnv.get().sparkFilesDir();
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
new file mode 100644
index 0000000000..2bab9d6e3d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.mapred
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
+import java.text.SimpleDateFormat
+import java.text.NumberFormat
+import java.io.IOException
+import java.util.Date
+
+import org.apache.spark.Logging
+import org.apache.spark.SerializableWritable
+
+/**
+ * Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public
+ * because we need to access this class from the `spark` package to use some package-private Hadoop
+ * functions, but this class should not be used directly by users.
+ *
+ * Saves the RDD using a JobConf, which should contain an output key class, an output value class,
+ * a filename to write to, etc, exactly like in a Hadoop MapReduce job.
+ */
+class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+
+ private val now = new Date()
+ private val conf = new SerializableWritable(jobConf)
+
+ private var jobID = 0
+ private var splitID = 0
+ private var attemptID = 0
+ private var jID: SerializableWritable[JobID] = null
+ private var taID: SerializableWritable[TaskAttemptID] = null
+
+ @transient private var writer: RecordWriter[AnyRef,AnyRef] = null
+ @transient private var format: OutputFormat[AnyRef,AnyRef] = null
+ @transient private var committer: OutputCommitter = null
+ @transient private var jobContext: JobContext = null
+ @transient private var taskContext: TaskAttemptContext = null
+
+ def preSetup() {
+ setIDs(0, 0, 0)
+ setConfParams()
+
+ val jCtxt = getJobContext()
+ getOutputCommitter().setupJob(jCtxt)
+ }
+
+
+ def setup(jobid: Int, splitid: Int, attemptid: Int) {
+ setIDs(jobid, splitid, attemptid)
+ setConfParams()
+ }
+
+ def open() {
+ val numfmt = NumberFormat.getInstance()
+ numfmt.setMinimumIntegerDigits(5)
+ numfmt.setGroupingUsed(false)
+
+ val outputName = "part-" + numfmt.format(splitID)
+ val path = FileOutputFormat.getOutputPath(conf.value)
+ val fs: FileSystem = {
+ if (path != null) {
+ path.getFileSystem(conf.value)
+ } else {
+ FileSystem.get(conf.value)
+ }
+ }
+
+ getOutputCommitter().setupTask(getTaskContext())
+ writer = getOutputFormat().getRecordWriter(
+ fs, conf.value, outputName, Reporter.NULL)
+ }
+
+ def write(key: AnyRef, value: AnyRef) {
+ if (writer!=null) {
+ //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ writer.write(key, value)
+ } else {
+ throw new IOException("Writer is null, open() has not been called")
+ }
+ }
+
+ def close() {
+ writer.close(Reporter.NULL)
+ }
+
+ def commit() {
+ val taCtxt = getTaskContext()
+ val cmtr = getOutputCommitter()
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ try {
+ cmtr.commitTask(taCtxt)
+ logInfo (taID + ": Committed")
+ } catch {
+ case e: IOException => {
+ logError("Error committing the output of task: " + taID.value, e)
+ cmtr.abortTask(taCtxt)
+ throw e
+ }
+ }
+ } else {
+ logWarning ("No need to commit output of task: " + taID.value)
+ }
+ }
+
+ def commitJob() {
+ // always ? Or if cmtr.needsTaskCommit ?
+ val cmtr = getOutputCommitter()
+ cmtr.commitJob(getJobContext())
+ }
+
+ def cleanup() {
+ getOutputCommitter().cleanupJob(getJobContext())
+ }
+
+ // ********* Private Functions *********
+
+ private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
+ if (format == null) {
+ format = conf.value.getOutputFormat()
+ .asInstanceOf[OutputFormat[AnyRef,AnyRef]]
+ }
+ return format
+ }
+
+ private def getOutputCommitter(): OutputCommitter = {
+ if (committer == null) {
+ committer = conf.value.getOutputCommitter
+ }
+ return committer
+ }
+
+ private def getJobContext(): JobContext = {
+ if (jobContext == null) {
+ jobContext = newJobContext(conf.value, jID.value)
+ }
+ return jobContext
+ }
+
+ private def getTaskContext(): TaskAttemptContext = {
+ if (taskContext == null) {
+ taskContext = newTaskAttemptContext(conf.value, taID.value)
+ }
+ return taskContext
+ }
+
+ private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
+ jobID = jobid
+ splitID = splitid
+ attemptID = attemptid
+
+ jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
+ taID = new SerializableWritable[TaskAttemptID](
+ new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
+ }
+
+ private def setConfParams() {
+ conf.value.set("mapred.job.id", jID.value.toString)
+ conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
+ conf.value.set("mapred.task.id", taID.value.toString)
+ conf.value.setBoolean("mapred.task.is.map", true)
+ conf.value.setInt("mapred.task.partition", splitID)
+ }
+}
+
+object SparkHadoopWriter {
+ def createJobID(time: Date, id: Int): JobID = {
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ val jobtrackerID = formatter.format(new Date())
+ return new JobID(jobtrackerID, id)
+ }
+
+ def createPathFromString(path: String, conf: JobConf): Path = {
+ if (path == null) {
+ throw new IllegalArgumentException("Output path is null")
+ }
+ var outputPath = new Path(path)
+ val fs = outputPath.getFileSystem(conf)
+ if (outputPath == null || fs == null) {
+ throw new IllegalArgumentException("Incorrectly formatted output path")
+ }
+ outputPath = outputPath.makeQualified(fs)
+ return outputPath
+ }
+}
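
Pieced together from the methods above, the intended call sequence looks roughly like the sketch below; the empty JobConf and the Text key/value are placeholders (a real JobConf would carry the output format, output path, and key/value classes, as the class comment says):

    import org.apache.hadoop.io.Text
    import org.apache.hadoop.mapred.{JobConf, SparkHadoopWriter}

    val conf = new JobConf()                    // placeholder; normally fully configured
    val writer = new SparkHadoopWriter(conf)

    writer.preSetup()                           // driver: set up the job's output committer
    writer.setup(jobid = 0, splitid = 0, attemptid = 0)
    writer.open()                               // task: create the RecordWriter for this split
    writer.write(new Text("key"), new Text("value"))
    writer.close()
    writer.commit()                             // commit this task attempt's output
    writer.commitJob()                          // driver: finalize the whole job
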
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
new file mode 100644
index 0000000000..b2dd668330
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import executor.TaskMetrics
+import scala.collection.mutable.ArrayBuffer
+
+class TaskContext(
+ val stageId: Int,
+ val splitId: Int,
+ val attemptId: Long,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty()
+) extends Serializable {
+
+ @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+
+ // Add a callback function to be executed on task completion. An example use
+ // is for HadoopRDD to register a callback to close the input stream.
+ def addOnCompleteCallback(f: () => Unit) {
+ onCompleteCallbacks += f
+ }
+
+ def executeOnCompleteCallbacks() {
+ onCompleteCallbacks.foreach{_()}
+ }
+}
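
A minimal sketch of the callback hook, mirroring the HadoopRDD example in the comment above (the input path is hypothetical):

    import org.apache.spark.TaskContext

    val context = new TaskContext(stageId = 0, splitId = 0, attemptId = 0L)

    val in = new java.io.FileInputStream("/tmp/part-00000")
    // Make sure the stream is closed once the task finishes, whatever happens in between.
    context.addOnCompleteCallback(() => in.close())

    // ... the task body would read from `in` here ...

    context.executeOnCompleteCallbacks()
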
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
new file mode 100644
index 0000000000..03bf268863
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
+ * tasks several times for "ephemeral" failures, and only report back failures that require some
+ * old stages to be resubmitted, such as shuffle map fetch failures.
+ */
+private[spark] sealed trait TaskEndReason
+
+private[spark] case object Success extends TaskEndReason
+
+private[spark]
+case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
+
+private[spark] case class FetchFailed(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int)
+ extends TaskEndReason
+
+private[spark] case class ExceptionFailure(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement],
+ metrics: Option[TaskMetrics])
+ extends TaskEndReason
+
+private[spark] case class OtherFailure(message: String) extends TaskEndReason
+
+private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
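
Since these are all case classes and objects, a scheduler component can branch on them with a plain match. The sketch below is hypothetical and, because the types are private[spark], would have to live inside the org.apache.spark package:

    def describe(reason: TaskEndReason): String = reason match {
      case Success => "task finished successfully"
      case Resubmitted => "task finished earlier but its output was lost; it will be re-run"
      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
        "lost map output " + mapId + " of shuffle " + shuffleId + " on " + bmAddress
      case ExceptionFailure(className, description, _, _) =>
        "task threw " + className + ": " + description
      case TaskResultTooBigFailure() => "task result was too large to send back to the driver"
      case OtherFailure(message) => "task failed: " + message
    }
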
diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala
new file mode 100644
index 0000000000..19ce8369d9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskState.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.mesos.Protos.{TaskState => MesosTaskState}
+
+private[spark] object TaskState
+ extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+
+ val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
+
+ val FINISHED_STATES = Set(FINISHED, FAILED, KILLED, LOST)
+
+ type TaskState = Value
+
+ def isFinished(state: TaskState) = FINISHED_STATES.contains(state)
+
+ def toMesos(state: TaskState): MesosTaskState = state match {
+ case LAUNCHING => MesosTaskState.TASK_STARTING
+ case RUNNING => MesosTaskState.TASK_RUNNING
+ case FINISHED => MesosTaskState.TASK_FINISHED
+ case FAILED => MesosTaskState.TASK_FAILED
+ case KILLED => MesosTaskState.TASK_KILLED
+ case LOST => MesosTaskState.TASK_LOST
+ }
+
+ def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match {
+ case MesosTaskState.TASK_STAGING => LAUNCHING
+ case MesosTaskState.TASK_STARTING => LAUNCHING
+ case MesosTaskState.TASK_RUNNING => RUNNING
+ case MesosTaskState.TASK_FINISHED => FINISHED
+ case MesosTaskState.TASK_FAILED => FAILED
+ case MesosTaskState.TASK_KILLED => KILLED
+ case MesosTaskState.TASK_LOST => LOST
+ }
+}
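
A quick round-trip sketch (again assuming code inside the org.apache.spark package, since the object is private[spark]):

    import org.apache.mesos.Protos.{TaskState => MesosTaskState}

    val state = TaskState.fromMesos(MesosTaskState.TASK_FINISHED)        // TaskState.FINISHED
    assert(TaskState.isFinished(state))                                  // FINISHED/FAILED/KILLED/LOST
    assert(TaskState.toMesos(TaskState.RUNNING) == MesosTaskState.TASK_RUNNING)
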
diff --git a/core/src/main/scala/org/apache/spark/Utils.scala b/core/src/main/scala/org/apache/spark/Utils.scala
new file mode 100644
index 0000000000..1e17deb010
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/Utils.scala
@@ -0,0 +1,780 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io._
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.util.{Locale, Random, UUID}
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
+import java.util.regex.Pattern
+
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConversions._
+import scala.io.Source
+
+import com.google.common.io.Files
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
+import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
+
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
+import org.apache.spark.deploy.SparkHadoopUtil
+import java.nio.ByteBuffer
+
+
+/**
+ * Various utility methods used by Spark.
+ */
+private object Utils extends Logging {
+
+ /** Serialize an object using Java serialization */
+ def serialize[T](o: T): Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ val oos = new ObjectOutputStream(bos)
+ oos.writeObject(o)
+ oos.close()
+ return bos.toByteArray
+ }
+
+ /** Deserialize an object using Java serialization */
+ def deserialize[T](bytes: Array[Byte]): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis)
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ /** Deserialize an object using Java serialization and the given ClassLoader */
+ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ois = new ObjectInputStream(bis) {
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, loader)
+ }
+ return ois.readObject.asInstanceOf[T]
+ }
+
+ /** Serialize via nested stream using specific serializer */
+ def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
+ val osWrapper = ser.serializeStream(new OutputStream {
+ def write(b: Int) = os.write(b)
+
+ override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len)
+ })
+ try {
+ f(osWrapper)
+ } finally {
+ osWrapper.close()
+ }
+ }
+
+ /** Deserialize via nested stream using specific serializer */
+ def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = {
+ val isWrapper = ser.deserializeStream(new InputStream {
+ def read(): Int = is.read()
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len)
+ })
+ try {
+ f(isWrapper)
+ } finally {
+ isWrapper.close()
+ }
+ }
+
+ /**
+ * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
+ */
+ def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
+ if (bb.hasArray) {
+ out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+ } else {
+ val bbval = new Array[Byte](bb.remaining())
+ bb.get(bbval)
+ out.write(bbval)
+ }
+ }
+
+ def isAlpha(c: Char): Boolean = {
+ (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
+ }
+
+ /** Split a string into words at non-alphabetic characters */
+ def splitWords(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var i = 0
+ while (i < s.length) {
+ var j = i
+ while (j < s.length && isAlpha(s.charAt(j))) {
+ j += 1
+ }
+ if (j > i) {
+ buf += s.substring(i, j)
+ }
+ i = j
+ while (i < s.length && !isAlpha(s.charAt(i))) {
+ i += 1
+ }
+ }
+ return buf
+ }
+
+ private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+
+ // Register the path to be deleted via shutdown hook
+ def registerShutdownDeleteDir(file: File) {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths += absolutePath
+ }
+ }
+
+ // Is the path already registered to be deleted via a shutdown hook ?
+ def hasShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.contains(absolutePath)
+ }
+ }
+
+  // Note: returns true if the file is a child of some registered path but not equal to it;
+  // false otherwise. This ensures that two shutdown hooks do not try to delete each other's
+  // paths, which would result in an IOException and incomplete cleanup.
+ def hasRootAsShutdownDeleteDir(file: File): Boolean = {
+ val absolutePath = file.getAbsolutePath()
+ val retval = shutdownDeletePaths.synchronized {
+ shutdownDeletePaths.find { path =>
+ !absolutePath.equals(path) && absolutePath.startsWith(path)
+ }.isDefined
+ }
+ if (retval) {
+ logInfo("path = " + file + ", already present as root for deletion.")
+ }
+ retval
+ }
+
+ /** Create a temporary directory inside the given parent directory */
+ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
+ var attempts = 0
+ val maxAttempts = 10
+ var dir: File = null
+ while (dir == null) {
+ attempts += 1
+ if (attempts > maxAttempts) {
+ throw new IOException("Failed to create a temp directory (under " + root + ") after " +
+ maxAttempts + " attempts!")
+ }
+ try {
+ dir = new File(root, "spark-" + UUID.randomUUID.toString)
+ if (dir.exists() || !dir.mkdirs()) {
+ dir = null
+ }
+ } catch { case e: IOException => ; }
+ }
+
+ registerShutdownDeleteDir(dir)
+
+ // Add a shutdown hook to delete the temp dir when the JVM exits
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
+ override def run() {
+        // Only delete the directory if no registered parent path will already delete it.
+ if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
+ }
+ })
+ dir
+ }
+
+ /** Copy all data from an InputStream to an OutputStream */
+ def copyStream(in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false)
+ {
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = in.read(buf)
+ if (n != -1) {
+ out.write(buf, 0, n)
+ }
+ }
+ if (closeStreams) {
+ in.close()
+ out.close()
+ }
+ }
+
+ /**
+ * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
+ */
+ def fetchFile(url: String, targetDir: File) {
+ val filename = url.split("/").last
+ val tempDir = getLocalDir
+ val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
+ val targetFile = new File(targetDir, filename)
+ val uri = new URI(url)
+ uri.getScheme match {
+ case "http" | "https" | "ftp" =>
+ logInfo("Fetching " + url + " to " + tempFile)
+ val in = new URL(url).openStream()
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ case "file" | null =>
+ // In the case of a local file, copy the local file to the target directory.
+ // Note the difference between uri vs url.
+ val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
+ if (targetFile.exists) {
+          // If the target file already exists, fail if its contents differ from the source file.
+ if (!Files.equal(sourceFile, targetFile)) {
+ throw new SparkException(
+ "File " + targetFile + " exists and does not match contents of" + " " + url)
+ } else {
+ // Do nothing if the file contents are the same, i.e. this file has been copied
+ // previously.
+ logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
+ + targetFile.getAbsolutePath)
+ }
+ } else {
+ // The file does not exist in the target directory. Copy it there.
+ logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
+ Files.copy(sourceFile, targetFile)
+ }
+ case _ =>
+ // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
+ val env = SparkEnv.get
+ val uri = new URI(url)
+ val conf = env.hadoop.newConfiguration()
+ val fs = FileSystem.get(uri, conf)
+ val in = fs.open(new Path(uri))
+ val out = new FileOutputStream(tempFile)
+ Utils.copyStream(in, out, true)
+ if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
+ tempFile.delete()
+ throw new SparkException("File " + targetFile + " exists and does not match contents of" +
+ " " + url)
+ } else {
+ Files.move(tempFile, targetFile)
+ }
+ }
+ // Decompress the file if it's a .tar or .tar.gz
+ if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xzf", filename), targetDir)
+ } else if (filename.endsWith(".tar")) {
+ logInfo("Untarring " + filename)
+ Utils.execute(Seq("tar", "-xf", filename), targetDir)
+ }
+ // Make the file executable - That's necessary for scripts
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ }
+
+ /**
+ * Shuffle the elements of a collection into a random order, returning the
+ * result in a new collection. Unlike scala.util.Random.shuffle, this method
+ * uses a local random number generator, avoiding inter-thread contention.
+ */
+ def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ randomizeInPlace(seq.toArray)
+ }
+
+ /**
+ * Shuffle the elements of an array into a random order, modifying the
+ * original array. Returns the original array.
+ */
+ def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
+ for (i <- (arr.length - 1) to 1 by -1) {
+      // Fisher-Yates: pick j uniformly from 0 to i (inclusive) for an unbiased shuffle.
+      val j = rand.nextInt(i + 1)
+ val tmp = arr(j)
+ arr(j) = arr(i)
+ arr(i) = tmp
+ }
+ arr
+ }
+
+ /**
+ * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
+ * Note, this is typically not used from within core spark.
+ */
+ lazy val localIpAddress: String = findLocalIpAddress()
+ lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
+
+ private def findLocalIpAddress(): String = {
+ val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
+ if (defaultIpOverride != null) {
+ defaultIpOverride
+ } else {
+ val address = InetAddress.getLocalHost
+ if (address.isLoopbackAddress) {
+ // Address resolves to something like 127.0.1.1, which happens on Debian; try to find
+ // a better address using the local network interfaces
+ for (ni <- NetworkInterface.getNetworkInterfaces) {
+ for (addr <- ni.getInetAddresses if !addr.isLinkLocalAddress &&
+ !addr.isLoopbackAddress && addr.isInstanceOf[Inet4Address]) {
+ // We've found an address that looks reasonable!
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + "; using " + addr.getHostAddress +
+ " instead (on interface " + ni.getName + ")")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ return addr.getHostAddress
+ }
+ }
+ logWarning("Your hostname, " + InetAddress.getLocalHost.getHostName + " resolves to" +
+ " a loopback address: " + address.getHostAddress + ", but we couldn't find any" +
+ " external IP address!")
+ logWarning("Set SPARK_LOCAL_IP if you need to bind to another address")
+ }
+ address.getHostAddress
+ }
+ }
+
+ private var customHostname: Option[String] = None
+
+ /**
+ * Allow setting a custom host name because when we run on Mesos we need to use the same
+ * hostname it reports to the master.
+ */
+ def setCustomHostname(hostname: String) {
+ // DEBUG code
+ Utils.checkHost(hostname)
+ customHostname = Some(hostname)
+ }
+
+ /**
+ * Get the local machine's hostname.
+ */
+ def localHostName(): String = {
+ customHostname.getOrElse(localIpAddressHostname)
+ }
+
+ def getAddressHostName(address: String): String = {
+ InetAddress.getByName(address).getHostName
+ }
+
+ def localHostPort(): String = {
+ val retval = System.getProperty("spark.hostPort", null)
+ if (retval == null) {
+ logErrorWithStack("spark.hostPort not set but invoking localHostPort")
+ return localHostName()
+ }
+
+ retval
+ }
+
+ def checkHost(host: String, message: String = "") {
+ assert(host.indexOf(':') == -1, message)
+ }
+
+ def checkHostPort(hostPort: String, message: String = "") {
+ assert(hostPort.indexOf(':') != -1, message)
+ }
+
+ // Used by DEBUG code : remove when all testing done
+ def logErrorWithStack(msg: String) {
+ try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
+ }
+
+  // Typically, this will be on the order of the number of nodes in the cluster.
+  // If not, we should change it to an LRU cache or something.
+ private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
+
+ def parseHostPort(hostPort: String): (String, Int) = {
+ {
+ // Check cache first.
+ var cached = hostPortParseResults.get(hostPort)
+ if (cached != null) return cached
+ }
+
+ val indx: Int = hostPort.lastIndexOf(':')
+    // This is potentially broken for IPv6 addresses, but Hadoop does not support IPv6 right now.
+    // For now, we assume that if a port is present it is valid; we do not check that it is an
+    // int > 0.
+ if (-1 == indx) {
+ val retval = (hostPort, 0)
+ hostPortParseResults.put(hostPort, retval)
+ return retval
+ }
+
+ val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
+ hostPortParseResults.putIfAbsent(hostPort, retval)
+ hostPortParseResults.get(hostPort)
+ }
+
+ private[spark] val daemonThreadFactory: ThreadFactory =
+ new ThreadFactoryBuilder().setDaemon(true).build()
+
+ /**
+ * Wrapper over newCachedThreadPool.
+ */
+ def newDaemonCachedThreadPool(): ThreadPoolExecutor =
+ Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+
+ /**
+   * Return a string reporting how many milliseconds have elapsed since the given start time.
+   * The parameter should be a timestamp in milliseconds.
+ */
+ def getUsedTimeMs(startTimeMs: Long): String = {
+ return " " + (System.currentTimeMillis - startTimeMs) + " ms"
+ }
+
+ /**
+ * Wrapper over newFixedThreadPool.
+ */
+ def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
+ Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+
+ /**
+ * Delete a file or directory and its contents recursively.
+ */
+ def deleteRecursively(file: File) {
+ if (file.isDirectory) {
+ for (child <- file.listFiles()) {
+ deleteRecursively(child)
+ }
+ }
+ if (!file.delete()) {
+ throw new IOException("Failed to delete: " + file)
+ }
+ }
+
+ /**
+ * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes.
+ * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM
+ * environment variable.
+ */
+ def memoryStringToMb(str: String): Int = {
+ val lower = str.toLowerCase
+ if (lower.endsWith("k")) {
+ (lower.substring(0, lower.length-1).toLong / 1024).toInt
+ } else if (lower.endsWith("m")) {
+ lower.substring(0, lower.length-1).toInt
+ } else if (lower.endsWith("g")) {
+ lower.substring(0, lower.length-1).toInt * 1024
+ } else if (lower.endsWith("t")) {
+ lower.substring(0, lower.length-1).toInt * 1024 * 1024
+ } else {// no suffix, so it's just a number in bytes
+ (lower.toLong / 1024 / 1024).toInt
+ }
+ }
+
+ /**
+ * Convert a quantity in bytes to a human-readable string such as "4.0 MB".
+ */
+ def bytesToString(size: Long): String = {
+ val TB = 1L << 40
+ val GB = 1L << 30
+ val MB = 1L << 20
+ val KB = 1L << 10
+
+ val (value, unit) = {
+ if (size >= 2*TB) {
+ (size.asInstanceOf[Double] / TB, "TB")
+ } else if (size >= 2*GB) {
+ (size.asInstanceOf[Double] / GB, "GB")
+ } else if (size >= 2*MB) {
+ (size.asInstanceOf[Double] / MB, "MB")
+ } else if (size >= 2*KB) {
+ (size.asInstanceOf[Double] / KB, "KB")
+ } else {
+ (size.asInstanceOf[Double], "B")
+ }
+ }
+ "%.1f %s".formatLocal(Locale.US, value, unit)
+ }
+
+ /**
+ * Returns a human-readable string representing a duration such as "35ms"
+ */
+ def msDurationToString(ms: Long): String = {
+ val second = 1000
+ val minute = 60 * second
+ val hour = 60 * minute
+
+ ms match {
+ case t if t < second =>
+ "%d ms".format(t)
+ case t if t < minute =>
+ "%.1f s".format(t.toFloat / second)
+ case t if t < hour =>
+ "%.1f m".format(t.toFloat / minute)
+ case t =>
+ "%.2f h".format(t.toFloat / hour)
+ }
+ }
+
+ /**
+ * Convert a quantity in megabytes to a human-readable string such as "4.0 MB".
+ */
+ def megabytesToString(megabytes: Long): String = {
+ bytesToString(megabytes * 1024L * 1024L)
+ }
+
+ /**
+ * Execute a command in the given working directory, throwing an exception if it completes
+ * with an exit code other than 0.
+ */
+ def execute(command: Seq[String], workingDir: File) {
+ val process = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ .redirectErrorStream(true)
+ .start()
+ new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val exitCode = process.waitFor()
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ }
+
+ /**
+ * Execute a command in the current working directory, throwing an exception if it completes
+ * with an exit code other than 0.
+ */
+ def execute(command: Seq[String]) {
+ execute(command, new File("."))
+ }
+
+ /**
+ * Execute a command and get its output, throwing an exception if it yields a code other than 0.
+ */
+ def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."),
+ extraEnvironment: Map[String, String] = Map.empty): String = {
+ val builder = new ProcessBuilder(command: _*)
+ .directory(workingDir)
+ val environment = builder.environment()
+ for ((key, value) <- extraEnvironment) {
+ environment.put(key, value)
+ }
+ val process = builder.start()
+ new Thread("read stderr for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+ val output = new StringBuffer
+ val stdoutThread = new Thread("read stdout for " + command(0)) {
+ override def run() {
+ for (line <- Source.fromInputStream(process.getInputStream).getLines) {
+ output.append(line)
+ }
+ }
+ }
+ stdoutThread.start()
+ val exitCode = process.waitFor()
+ stdoutThread.join() // Wait for it to finish reading output
+ if (exitCode != 0) {
+ throw new SparkException("Process " + command + " exited with code " + exitCode)
+ }
+ output.toString
+ }
+
+ /**
+ * A regular expression to match classes of the "core" Spark API that we want to skip when
+ * finding the call site of a method.
+ */
+  private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.rdd)?\.[A-Z]""".r
+
+ private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
+ val firstUserLine: Int, val firstUserClass: String)
+
+ /**
+ * When called inside a class in the spark package, returns the name of the user code class
+ * (outside the spark package) that called into Spark, as well as which Spark method they called.
+ * This is used, for example, to tell users where in their code each RDD got created.
+ */
+ def getCallSiteInfo: CallSiteInfo = {
+ val trace = Thread.currentThread.getStackTrace().filter( el =>
+ (!el.getMethodName.contains("getStackTrace")))
+
+ // Keep crawling up the stack trace until we find the first function not inside of the spark
+ // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
+ // transformation, a SparkContext function (such as parallelize), or anything else that leads
+ // to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
+ var lastSparkMethod = "<unknown>"
+ var firstUserFile = "<unknown>"
+ var firstUserLine = 0
+ var finished = false
+ var firstUserClass = "<unknown>"
+
+ for (el <- trace) {
+ if (!finished) {
+ if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName) != None) {
+ lastSparkMethod = if (el.getMethodName == "<init>") {
+ // Spark method is a constructor; get its class name
+ el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
+ } else {
+ el.getMethodName
+ }
+ }
+ else {
+ firstUserLine = el.getLineNumber
+ firstUserFile = el.getFileName
+ firstUserClass = el.getClassName
+ finished = true
+ }
+ }
+ }
+ new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
+ }
+
+ def formatSparkCallSite = {
+ val callSiteInfo = getCallSiteInfo
+ "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
+ callSiteInfo.firstUserLine)
+ }
+
+ /** Return a string containing part of a file from byte 'start' to 'end'. */
+ def offsetBytes(path: String, start: Long, end: Long): String = {
+ val file = new File(path)
+ val length = file.length()
+ val effectiveEnd = math.min(length, end)
+ val effectiveStart = math.max(0, start)
+ val buff = new Array[Byte]((effectiveEnd-effectiveStart).toInt)
+ val stream = new FileInputStream(file)
+
+ stream.skip(effectiveStart)
+ stream.read(buff)
+ stream.close()
+ Source.fromBytes(buff).mkString
+ }
+
+ /**
+ * Clone an object using a Spark serializer.
+ */
+ def clone[T](value: T, serializer: SerializerInstance): T = {
+ serializer.deserialize[T](serializer.serialize(value))
+ }
+
+ /**
+ * Detect whether this thread might be executing a shutdown hook. Will always return true if
+   * the current thread is running a shutdown hook but may spuriously return true otherwise (e.g.
+ * if System.exit was just called by a concurrent thread).
+ *
+ * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing
+ * an IllegalStateException.
+ */
+ def inShutdown(): Boolean = {
+ try {
+ val hook = new Thread {
+ override def run() {}
+ }
+ Runtime.getRuntime.addShutdownHook(hook)
+ Runtime.getRuntime.removeShutdownHook(hook)
+ } catch {
+ case ise: IllegalStateException => return true
+ }
+ return false
+ }
+
+ def isSpace(c: Char): Boolean = {
+ " \t\r\n".indexOf(c) != -1
+ }
+
+ /**
+ * Split a string of potentially quoted arguments from the command line the way that a shell
+ * would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
+ * then it would be parsed as three arguments: 'a', 'b c' and 'd'.
+ */
+ def splitCommandString(s: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ var inWord = false
+ var inSingleQuote = false
+ var inDoubleQuote = false
+ var curWord = new StringBuilder
+ def endWord() {
+ buf += curWord.toString
+ curWord.clear()
+ }
+ var i = 0
+ while (i < s.length) {
+ var nextChar = s.charAt(i)
+ if (inDoubleQuote) {
+ if (nextChar == '"') {
+ inDoubleQuote = false
+ } else if (nextChar == '\\') {
+ if (i < s.length - 1) {
+ // Append the next character directly, because only " and \ may be escaped in
+ // double quotes after the shell's own expansion
+ curWord.append(s.charAt(i + 1))
+ i += 1
+ }
+ } else {
+ curWord.append(nextChar)
+ }
+ } else if (inSingleQuote) {
+ if (nextChar == '\'') {
+ inSingleQuote = false
+ } else {
+ curWord.append(nextChar)
+ }
+ // Backslashes are not treated specially in single quotes
+ } else if (nextChar == '"') {
+ inWord = true
+ inDoubleQuote = true
+ } else if (nextChar == '\'') {
+ inWord = true
+ inSingleQuote = true
+ } else if (!isSpace(nextChar)) {
+ curWord.append(nextChar)
+ inWord = true
+ } else if (inWord && isSpace(nextChar)) {
+ endWord()
+ inWord = false
+ }
+ i += 1
+ }
+ if (inWord || inDoubleQuote || inSingleQuote) {
+ endWord()
+ }
+ return buf
+ }
+
+  /* Calculates 'x' modulo 'mod', taking the sign of x into consideration:
+   * if 'x' is negative, then 'x' % 'mod' is negative too,
+   * so the function returns (x % mod) + mod in that case.
+ */
+ def nonNegativeMod(x: Int, mod: Int): Int = {
+ val rawMod = x % mod
+ rawMod + (if (rawMod < 0) mod else 0)
+ }
+}
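
The behavior of a few of these helpers is easiest to see by example; the expected values below follow directly from the definitions (Utils is package-private, so this assumes code in the org.apache.spark package):

    // Shell-style splitting, as in the splitCommandString scaladoc.
    assert(Utils.splitCommandString("a \"b c\" d") == Seq("a", "b c", "d"))

    // JVM-style memory strings to megabytes.
    assert(Utils.memoryStringToMb("300m") == 300)
    assert(Utils.memoryStringToMb("1g") == 1024)

    // Human-readable byte counts (units switch at 2x the unit boundary).
    assert(Utils.bytesToString(3L * 1024 * 1024) == "3.0 MB")

    // The result of nonNegativeMod always lands in [0, mod).
    assert(Utils.nonNegativeMod(-1, 3) == 2)

    // parseHostPort: the port defaults to 0 when absent.
    assert(Utils.parseHostPort("worker1:7077") == ("worker1", 7077))
    assert(Utils.parseHostPort("worker1") == ("worker1", 0))
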
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
new file mode 100644
index 0000000000..cb25ff728e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.util.StatCounter
+import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.storage.StorageLevel
+import java.lang.Double
+import org.apache.spark.Partitioner
+
+class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
+
+ override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
+
+ override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
+
+ override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD =
+ new JavaDoubleRDD(rdd.map(_.doubleValue))
+
+ // Common RDD functions
+
+ import JavaDoubleRDD.fromRDD
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
+
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
+ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+
+  // first() has to be overridden here in order for its return type to be Double instead of Object.
+ override def first(): Double = srdd.first()
+
+ // Transformations (return a new RDD)
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
+ def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
+ fromRDD(srdd.filter(x => f(x).booleanValue()))
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
+ fromRDD(srdd.coalesce(numPartitions, shuffle))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD =
+ fromRDD(srdd.subtract(other, p))
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
+ fromRDD(srdd.sample(withReplacement, fraction, seed))
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
+ def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
+
+ // Double RDD functions
+
+ /** Add up the elements in this RDD. */
+ def sum(): Double = srdd.sum()
+
+ /**
+ * Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count
+ * of the RDD's elements in one operation.
+ */
+ def stats(): StatCounter = srdd.stats()
+
+ /** Compute the mean of this RDD's elements. */
+ def mean(): Double = srdd.mean()
+
+ /** Compute the variance of this RDD's elements. */
+ def variance(): Double = srdd.variance()
+
+ /** Compute the standard deviation of this RDD's elements. */
+ def stdev(): Double = srdd.stdev()
+
+ /**
+ * Compute the sample standard deviation of this RDD's elements (which corrects for bias in
+ * estimating the standard deviation by dividing by N-1 instead of N).
+ */
+ def sampleStdev(): Double = srdd.sampleStdev()
+
+ /**
+ * Compute the sample variance of this RDD's elements (which corrects for bias in
+   * estimating the variance by dividing by N-1 instead of N).
+ */
+ def sampleVariance(): Double = srdd.sampleVariance()
+
+ /** Return the approximate mean of the elements in this RDD. */
+ def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ srdd.meanApprox(timeout, confidence)
+
+ /** (Experimental) Approximate operation to return the mean within a timeout. */
+ def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
+
+ /** (Experimental) Approximate operation to return the sum within a timeout. */
+ def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ srdd.sumApprox(timeout, confidence)
+
+ /** (Experimental) Approximate operation to return the sum within a timeout. */
+ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+}
+
+object JavaDoubleRDD {
+ def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd)
+
+ implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd
+}
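
A short sketch of the wrapper in use (the SparkContext `sc` and the numbers are illustrative):

    import org.apache.spark.api.java.JavaDoubleRDD

    val doubles: JavaDoubleRDD = JavaDoubleRDD.fromRDD(sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0)))

    println(doubles.sum())    // 10.0
    println(doubles.mean())   // 2.5
    println(doubles.stats())  // mean, variance and count computed in one pass
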
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
new file mode 100644
index 0000000000..09da35aee6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -0,0 +1,601 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import java.util.{List => JList}
+import java.util.Comparator
+
+import scala.Tuple2
+import scala.collection.JavaConversions._
+
+import com.google.common.base.Optional
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.OutputFormat
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.HashPartitioner
+import org.apache.spark.Partitioner
+import org.apache.spark.Partitioner._
+import org.apache.spark.RDD
+import org.apache.spark.SparkContext.rddToPairRDDFunctions
+import org.apache.spark.api.java.function.{Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.partial.BoundedDouble
+import org.apache.spark.partial.PartialResult
+import org.apache.spark.rdd.OrderedRDDFunctions
+import org.apache.spark.storage.StorageLevel
+
+
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
+ implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+
+ override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
+
+ override val classManifest: ClassManifest[(K, V)] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+
+ import JavaPairRDD._
+
+ // Common RDD functions
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
+
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. Can only be called once on each RDD.
+ */
+ def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.persist(newLevel))
+
+ // Transformations (return a new RDD)
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
+ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions))
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
+ fromRDD(rdd.coalesce(numPartitions, shuffle))
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
+ def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.union(other.rdd))
+
+ // first() has to be overridden here so that the generated method has the signature
+ // 'public scala.Tuple2 first()'; if the trait's definition is used,
+ // then the method has the signature 'public java.lang.Object first()',
+ // causing NoSuchMethodErrors at runtime.
+ override def first(): (K, V) = rdd.first()
+
+ // Pair RDD functions
+
+ /**
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+ * "combined type" C * Note that V and C can be different -- for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * functions:
+ *
+ * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+ * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+ * - `mergeCombiners`, to combine two C's into a single one.
+ *
+ * In addition, users can control the partitioning of the output RDD, and whether to perform
+ * map-side aggregation (if a mapper can produce multiple items with the same key).
+ */
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ partitioner: Partitioner): JavaPairRDD[K, C] = {
+ implicit val cm: ClassManifest[C] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ fromRDD(rdd.combineByKey(
+ createCombiner,
+ mergeValue,
+ mergeCombiners,
+ partitioner
+ ))
+ }
+
+ /**
+ * Simplified version of combineByKey that hash-partitions the output RDD.
+ */
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ numPartitions: Int): JavaPairRDD[K, C] =
+ combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce.
+ */
+ def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(partitioner, func))
+
+ /**
+ * Merge the values for each key using an associative reduce function, but return the results
+ * immediately to the master as a Map. This will also perform the merging locally on each mapper
+ * before sending results to a reducer, similarly to a "combiner" in MapReduce.
+ */
+ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
+ mapAsJavaMap(rdd.reduceByKeyLocally(func))
+
+ /** Count the number of elements for each key, and return the result to the master as a Map. */
+ def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
+ def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
+ rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+
+ /**
+ * (Experimental) Approximate version of countByKey that can return a partial result if it does
+ * not finish within a timeout.
+ */
+ def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
+ : PartialResult[java.util.Map[K, BoundedDouble]] =
+ rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))
+
+ /**
+ * Merge the values for each key using an associative function and a neutral "zero value" which may
+ * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
+ * list concatenation, 0 for addition, or 1 for multiplication.).
+ */
+ def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.foldByKey(zeroValue)(func))
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
+ */
+ def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.reduceByKey(func, numPartitions))
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Allows controlling the
+ * partitioning of the resulting key-value pair RDD by passing a Partitioner.
+ */
+ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+   * resulting RDD into `numPartitions` partitions.
+ */
+ def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions)))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] =
+ fromRDD(rdd.subtract(other, p))
+
+ /**
+ * Return a copy of the RDD partitioned using the specified partitioner.
+ */
+ def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] =
+ fromRDD(rdd.partitionBy(partitioner))
+
+ /**
+   * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+   * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+   * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
+ */
+ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, partitioner))
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (V, Optional[W])] = {
+ val joinResult = rdd.leftOuterJoin(other, partitioner)
+ fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
+ * partition the output RDD.
+ */
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (Optional[V], W)] = {
+ val joinResult = rdd.rightOuterJoin(other, partitioner)
+ fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+ }
+
+ /**
+ * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
+ * partitioner/parallelism level.
+ */
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
+ implicit val cm: ClassManifest[C] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
+ }
+
+ /**
+ * Merge the values for each key using an associative reduce function. This will also perform
+ * the merging locally on each mapper before sending results to a reducer, similarly to a
+ * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
+ * parallelism level.
+ */
+ def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
+ fromRDD(reduceByKey(defaultPartitioner(rdd), func))
+ }
+
+ /**
+ * Group the values for each key in the RDD into a single sequence. Hash-partitions the
+ * resulting RDD with the existing partitioner/parallelism level.
+ */
+ def groupByKey(): JavaPairRDD[K, JList[V]] =
+ fromRDD(groupByResultToJava(rdd.groupByKey()))
+
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
+ def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other))
+
+ /**
+ * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
+ * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
+ * (k, v2) is in `other`. Performs a hash join across the cluster.
+ */
+ def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] =
+ fromRDD(rdd.join(other, numPartitions))
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * using the existing partitioner/parallelism level.
+ */
+ def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = {
+ val joinResult = rdd.leftOuterJoin(other)
+ fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+ }
+
+ /**
+ * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
+ * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
+ * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+ * into `numPartitions` partitions.
+ */
+ def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = {
+ val joinResult = rdd.leftOuterJoin(other, numPartitions)
+ fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD using the existing partitioner/parallelism level.
+ */
+ def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = {
+ val joinResult = rdd.rightOuterJoin(other)
+ fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+ }
+
+ /**
+ * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
+ * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
+ * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+ * RDD into the given number of partitions.
+ */
+ def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = {
+ val joinResult = rdd.rightOuterJoin(other, numPartitions)
+ fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+ }
+
+ /**
+ * Return the key-value pairs in this RDD to the master as a Map.
+ */
+ def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+
+ /**
+ * Pass each value in the key-value pair RDD through a map function without changing the keys;
+ * this also retains the original RDD's partitioning.
+ */
+ def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
+ implicit val cm: ClassManifest[U] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ fromRDD(rdd.mapValues(f))
+ }
+
+ /**
+ * Pass each value in the key-value pair RDD through a flatMap function without changing the
+ * keys; this also retains the original RDD's partitioning.
+ */
+ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: V) => f.apply(x).asScala
+ implicit val cm: ClassManifest[U] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ fromRDD(rdd.flatMapValues(fn))
+ }
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
+ : JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner)
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.cogroup(other)))
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
+
+ /**
+ * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
+ * list of values for that key in `this` as well as `other`.
+ */
+ def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])]
+ = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions)))
+
+ /**
+ * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
+ * tuple with the list of values for that key in `this`, `other1` and `other2`.
+ */
+ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int)
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
+
+ /** Alias for cogroup. */
+ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
+ fromRDD(cogroupResultToJava(rdd.groupWith(other)))
+
+ /** Alias for cogroup. */
+ def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
+ : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
+ fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
+
+ /**
+ * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
+ * RDD has a known partitioner by only searching the partition that the key maps to.
+ */
+ def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key))
+
+ /** Output the RDD to any Hadoop-supported file system. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: JobConf) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ /** Output the RDD to any Hadoop-supported file system. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
+ }
+
+ /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
+ def saveAsHadoopFile[F <: OutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ codec: Class[_ <: CompressionCodec]) {
+ rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
+ }
+
+ /** Output the RDD to any Hadoop-supported file system. */
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F],
+ conf: Configuration) {
+ rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
+ }
+
+ /** Output the RDD to any Hadoop-supported file system. */
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
+ path: String,
+ keyClass: Class[_],
+ valueClass: Class[_],
+ outputFormatClass: Class[F]) {
+ rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
+ }
+
+ /**
+ * Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
+ * that storage system. The JobConf should set an OutputFormat and any output paths required
+ * (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
+ * MapReduce job.
+ */
+ def saveAsHadoopDataset(conf: JobConf) {
+ rdd.saveAsHadoopDataset(conf)
+ }
+
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements in
+ * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
+ * ordered list of records (in the `save` case, they will be written to multiple `part-X` files
+ * in the filesystem, in order of the keys).
+ */
+ def sortByKey(): JavaPairRDD[K, V] = sortByKey(true)
+
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
+ sortByKey(comp, ascending)
+ }
+
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true)
+
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
+ class KeyOrdering(val a: K) extends Ordered[K] {
+ override def compare(b: K) = comp.compare(a, b)
+ }
+ implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
+ }
+
+ /**
+ * Return an RDD with the keys of each tuple.
+ */
+ def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
+
+ /**
+ * Return an RDD with the values of each tuple.
+ */
+ def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
+}
+
+object JavaPairRDD {
+ def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
+ vcm: ClassManifest[T]): RDD[(K, JList[T])] =
+ rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
+
+ def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
+ vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
+ Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+
+ def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
+ Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+ JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
+ (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
+ seqAsJavaList(x._2),
+ seqAsJavaList(x._3)))
+
+ def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd)
+
+ implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+}
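
For context, a minimal word-count sketch of how the JavaPairRDD API above is driven from Java with pre-Java-8 anonymous classes. The local master URL, sample data, and class name are illustrative assumptions, not part of this file; the `map` overload taking a `PairFunction` is what yields a `JavaPairRDD` here.

    import java.util.Arrays;
    import java.util.Map;
    import scala.Tuple2;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function2;
    import org.apache.spark.api.java.function.PairFunction;

    public class WordCountSketch {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "WordCountSketch");
        JavaRDD<String> words = sc.parallelize(Arrays.asList("a", "b", "a", "c", "b", "a"));
        // map with a PairFunction produces a JavaPairRDD of (word, 1) pairs.
        JavaPairRDD<String, Integer> ones = words.map(new PairFunction<String, String, Integer>() {
          public Tuple2<String, Integer> call(String w) { return new Tuple2<String, Integer>(w, 1); }
        });
        // reduceByKey merges values per key, combining locally on each mapper before the shuffle.
        JavaPairRDD<String, Integer> counts = ones.reduceByKey(new Function2<Integer, Integer, Integer>() {
          public Integer call(Integer a, Integer b) { return a + b; }
        });
        Map<String, Integer> result = counts.collectAsMap();
        System.out.println(result);
        sc.stop();
      }
    }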
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
new file mode 100644
index 0000000000..68cfcf5999
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark._
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.storage.StorageLevel
+
+class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
+JavaRDDLike[T, JavaRDD[T]] {
+
+ override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
+
+ // Common RDD functions
+
+ /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
+ def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
+
+ /**
+ * Set this RDD's storage level to persist its values across operations after the first time
+ * it is computed. This can only be used to assign a new storage level if the RDD does not
+ * have a storage level set yet.
+ */
+ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ */
+ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+
+ // Transformations (return a new RDD)
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
+
+ /**
+ * Return a new RDD containing the distinct elements in this RDD.
+ */
+ def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
+
+ /**
+ * Return a new RDD containing only the elements that satisfy a predicate.
+ */
+ def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
+ wrapRDD(rdd.filter((x => f(x).booleanValue())))
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
+
+ /**
+ * Return a new RDD that is reduced into `numPartitions` partitions.
+ */
+ def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
+ rdd.coalesce(numPartitions, shuffle)
+
+ /**
+ * Return a sampled subset of this RDD.
+ */
+ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+ wrapRDD(rdd.sample(withReplacement, fraction, seed))
+
+ /**
+ * Return the union of this RDD and another one. Any identical elements will appear multiple
+ * times (use `.distinct()` to eliminate them).
+ */
+ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ *
+ * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+ * RDD will be <= us.
+ */
+ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] =
+ wrapRDD(rdd.subtract(other, numPartitions))
+
+ /**
+ * Return an RDD with the elements from `this` that are not in `other`.
+ */
+ def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
+ wrapRDD(rdd.subtract(other, p))
+}
+
+object JavaRDD {
+
+ implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
+
+ implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
+}
+
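A short sketch of the basic JavaRDD transformations defined above, reusing the `sc` context and imports from the word-count sketch earlier (plus `org.apache.spark.api.java.function.Function`); the data is illustrative.

    JavaRDD<Integer> a = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
    JavaRDD<Integer> b = sc.parallelize(Arrays.asList(4, 5, 6));

    // filter expects a Function returning java.lang.Boolean.
    JavaRDD<Integer> even = a.filter(new Function<Integer, Boolean>() {
      public Boolean call(Integer x) { return x % 2 == 0; }
    });

    JavaRDD<Integer> both = a.union(b);       // duplicates are retained
    JavaRDD<Integer> onlyInA = a.subtract(b); // 1, 2, 3
    JavaRDD<Integer> unique = both.distinct();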
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
new file mode 100644
index 0000000000..1ad8514980
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import java.util.{List => JList, Comparator}
+import scala.Tuple2
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.io.compress.CompressionCodec
+import org.apache.spark.{SparkContext, Partition, RDD, TaskContext}
+import org.apache.spark.api.java.JavaPairRDD._
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
+import org.apache.spark.partial.{PartialResult, BoundedDouble}
+import org.apache.spark.storage.StorageLevel
+import com.google.common.base.Optional
+
+
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+ def wrapRDD(rdd: RDD[T]): This
+
+ implicit val classManifest: ClassManifest[T]
+
+ def rdd: RDD[T]
+
+ /** Set of partitions in this RDD. */
+ def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
+
+ /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
+ def context: SparkContext = rdd.context
+
+ /** A unique ID for this RDD (within its SparkContext). */
+ def id: Int = rdd.id
+
+ /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
+ def getStorageLevel: StorageLevel = rdd.getStorageLevel
+
+ /**
+ * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
+ * This should ''not'' be called by users directly, but is available for implementors of custom
+ * subclasses of RDD.
+ */
+ def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] =
+ asJavaIterator(rdd.iterator(split, taskContext))
+
+ // Transformations (return a new RDD)
+
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
+ def map[R](f: JFunction[T, R]): JavaRDD[R] =
+ new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
+
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
+ def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
+ new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
+
+ /**
+ * Return a new RDD by applying a function to all elements of this RDD.
+ */
+ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
+ }
+
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
+ }
+
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
+ }
+
+ /**
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
+ */
+ def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
+ import scala.collection.JavaConverters._
+ def fn = (x: T) => f.apply(x).asScala
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
+ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
+ def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD.
+ */
+ def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
+ JavaPairRDD[K2, V2] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
+ }
+
+ /**
+ * Return an RDD created by coalescing all elements within each partition into an array.
+ */
+ def glom(): JavaRDD[JList[T]] =
+ new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+
+ /**
+ * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
+ * elements (a, b) where a is in `this` and b is in `other`.
+ */
+ def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
+ JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
+ other.classManifest)
+
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
+ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[JList[T]] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
+ }
+
+ /**
+ * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
+ * mapping to that key.
+ */
+ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[JList[T]] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
+ }
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
+ def pipe(command: String): JavaRDD[String] = rdd.pipe(command)
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
+ def pipe(command: JList[String]): JavaRDD[String] =
+ rdd.pipe(asScalaBuffer(command))
+
+ /**
+ * Return an RDD created by piping elements to a forked external process.
+ */
+ def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] =
+ rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env))
+
+ /**
+ * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+ * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+ * partitions* and the *same number of elements in each partition* (e.g. one was made through
+ * a map on the other).
+ */
+ def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
+ JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
+ }
+
+ /**
+ * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+ * applying a function to the zipped partitions. Assumes that all the RDDs have the
+ * *same number of partitions*, but does *not* require them to have the same number
+ * of elements in each partition.
+ */
+ def zipPartitions[U, V](
+ other: JavaRDDLike[U, _],
+ f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
+ def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
+ f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
+ JavaRDD.fromRDD(
+ rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
+ }
+
+ // Actions (launch a job to return a value to the user program)
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreach(f: VoidFunction[T]) {
+ val cleanF = rdd.context.clean(f)
+ rdd.foreach(cleanF)
+ }
+
+ /**
+ * Return an array that contains all of the elements in this RDD.
+ */
+ def collect(): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.collect().toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Reduces the elements of this RDD using the specified commutative and associative binary operator.
+ */
+ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
+
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using a
+ * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
+ * modify t1 and return it as its result value to avoid object allocation; however, it should not
+ * modify t2.
+ */
+ def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
+ rdd.fold(zeroValue)(f)
+
+ /**
+ * Aggregate the elements of each partition, and then the results for all the partitions, using
+ * given combine functions and a neutral "zero value". This function can return a different result
+ * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into a U
+ * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U to avoid memory
+ * allocation.
+ */
+ def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U]): U =
+ rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
+
+ /**
+ * Return the number of elements in the RDD.
+ */
+ def count(): Long = rdd.count()
+
+ /**
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
+ */
+ def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
+ rdd.countApprox(timeout, confidence)
+
+ /**
+ * (Experimental) Approximate version of count() that returns a potentially incomplete result
+ * within a timeout, even if not all tasks have finished.
+ */
+ def countApprox(timeout: Long): PartialResult[BoundedDouble] =
+ rdd.countApprox(timeout)
+
+ /**
+ * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
+ * combine step happens locally on the master, equivalent to running a single reduce task.
+ */
+ def countByValue(): java.util.Map[T, java.lang.Long] =
+ mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+
+ /**
+ * (Experimental) Approximate version of countByValue().
+ */
+ def countByValueApprox(
+ timeout: Long,
+ confidence: Double
+ ): PartialResult[java.util.Map[T, BoundedDouble]] =
+ rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+
+ /**
+ * (Experimental) Approximate version of countByValue().
+ */
+ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
+ rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+
+ /**
+ * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
+ * it will be slow if a lot of partitions are required. In that case, use collect() to get the
+ * whole RDD instead.
+ */
+ def take(num: Int): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.take(num).toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+ import scala.collection.JavaConversions._
+ val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Return the first element in this RDD.
+ */
+ def first(): T = rdd.first()
+
+ /**
+ * Save this RDD as a text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
+
+
+ /**
+ * Save this RDD as a compressed text file, using string representations of elements.
+ */
+ def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
+ rdd.saveAsTextFile(path, codec)
+
+ /**
+ * Save this RDD as a SequenceFile of serialized objects.
+ */
+ def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
+
+ /**
+ * Creates tuples of the elements in this RDD by applying `f`.
+ */
+ def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
+ implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ JavaPairRDD.fromRDD(rdd.keyBy(f))
+ }
+
+ /**
+ * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+ * directory set with SparkContext.setCheckpointDir() and all references to its parent
+ * RDDs will be removed. This function must be called before any job has been
+ * executed on this RDD. It is strongly recommended that this RDD is persisted in
+ * memory, otherwise saving it on a file will require recomputation.
+ */
+ def checkpoint() = rdd.checkpoint()
+
+ /**
+ * Return whether this RDD has been checkpointed or not.
+ */
+ def isCheckpointed: Boolean = rdd.isCheckpointed
+
+ /**
+ * Gets the name of the file to which this RDD was checkpointed, if any.
+ */
+ def getCheckpointFile(): Optional[String] = {
+ JavaUtils.optionToOptional(rdd.getCheckpointFile)
+ }
+
+ /** A description of this RDD and its recursive dependencies for debugging. */
+ def toDebugString(): String = {
+ rdd.toDebugString
+ }
+
+ /**
+ * Returns the top K elements from this RDD as defined by
+ * the specified Comparator[T].
+ * @param num the number of top elements to return
+ * @param comp the comparator that defines the order
+ * @return a list of the top elements
+ */
+ def top(num: Int, comp: Comparator[T]): JList[T] = {
+ import scala.collection.JavaConversions._
+ val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
+ val arr: java.util.Collection[T] = topElems.toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Returns the top K elements from this RDD using the
+ * natural ordering for T.
+ * @param num the number of top elements to return
+ * @return a list of the top elements
+ */
+ def top(num: Int): JList[T] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
+ top(num, comp)
+ }
+
+ /**
+ * Returns the first K elements from this RDD as defined by
+ * the specified Comparator[T] and maintains the order.
+ * @param num the number of top elements to return
+ * @param comp the comparator that defines the order
+ * @return a list of the first elements
+ */
+ def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = {
+ import scala.collection.JavaConversions._
+ val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp))
+ val arr: java.util.Collection[T] = topElems.toSeq
+ new java.util.ArrayList(arr)
+ }
+
+ /**
+ * Returns the first K elements from this RDD using the
+ * natural ordering for T while maintaining the order.
+ * @param num the number of top elements to return
+ * @return a list of the first elements
+ */
+ def takeOrdered(num: Int): JList[T] = {
+ val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
+ takeOrdered(num, comp)
+ }
+}
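A sketch of the actions declared in this trait, again reusing `sc` and the imports from the earlier sketches (plus `java.util.List`); values in the comments are what this illustrative data would produce.

    JavaRDD<Integer> nums = sc.parallelize(Arrays.asList(5, 1, 4, 2, 3));

    // reduce with a commutative, associative operator
    Integer sum = nums.reduce(new Function2<Integer, Integer, Integer>() {
      public Integer call(Integer x, Integer y) { return x + y; }
    });                                        // 15

    long n = nums.count();                     // 5
    List<Integer> top2 = nums.top(2);          // [5, 4] by natural ordering
    List<Integer> low2 = nums.takeOrdered(2);  // [1, 2], smallest first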
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
new file mode 100644
index 0000000000..618a7b3bf7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import java.util.{Map => JMap}
+
+import scala.collection.JavaConversions
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+
+import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
+import org.apache.spark.SparkContext.IntAccumulatorParam
+import org.apache.spark.SparkContext.DoubleAccumulatorParam
+import org.apache.spark.broadcast.Broadcast
+
+import com.google.common.base.Optional
+
+/**
+ * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and
+ * works with Java collections instead of Scala ones.
+ */
+class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
+ */
+ def this(master: String, appName: String) = this(new SparkContext(master, appName))
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jarFile JAR file to send to the cluster. This can be a path on the local file system
+ * or an HDFS, HTTP, HTTPS, or FTP URL.
+ */
+ def this(master: String, appName: String, sparkHome: String, jarFile: String) =
+ this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ */
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String]) =
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq))
+
+ /**
+ * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
+ * @param appName A name for your application, to display on the cluster web UI
+ * @param sparkHome The SPARK_HOME directory on the slave nodes
+ * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
+ * system or HDFS, HTTP, HTTPS, or FTP URLs.
+ * @param environment Environment variables to set on worker nodes
+ */
+ def this(master: String, appName: String, sparkHome: String, jars: Array[String],
+ environment: JMap[String, String]) =
+ this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
+
+ private[spark] val env = sc.env
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
+ }
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
+ parallelize(list, sc.defaultParallelism)
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
+ : JavaPairRDD[K, V] = {
+ implicit val kcm: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val vcm: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
+ }
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
+ parallelizePairs(list, sc.defaultParallelism)
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
+ JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
+ numSlices))
+
+ /** Distribute a local Scala collection to form an RDD. */
+ def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
+ parallelizeDoubles(list, sc.defaultParallelism)
+
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
+ def textFile(path: String): JavaRDD[String] = sc.textFile(path)
+
+ /**
+ * Read a text file from HDFS, a local file system (available on all nodes), or any
+ * Hadoop-supported file system URI, and return it as an RDD of Strings.
+ */
+ def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
+
+ /** Get an RDD for a Hadoop SequenceFile with given key and value types. */
+ def sequenceFile[K, V](path: String,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
+ }
+
+ /** Get an RDD for a Hadoop SequenceFile. */
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
+ JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
+ }
+
+ /**
+ * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
+ * BytesWritable values that contain a serialized partition. This is still an experimental storage
+ * format and may not be supported exactly as is in future Spark releases. It will also be pretty
+ * slow if you use the default serializer (Java serialization), though the nice thing about it is
+ * that there's very little effort required to save arbitrary objects.
+ */
+ def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.objectFile(path, minSplits)(cm)
+ }
+
+ /**
+ * Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
+ * BytesWritable values that contain a serialized partition. This is still an experimental storage
+ * format and may not be supported exactly as is in future Spark releases. It will also be pretty
+ * slow if you use the default serializer (Java serialization), though the nice thing about it is
+ * that there's very little effort required to save arbitrary objects.
+ */
+ def objectFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ sc.objectFile(path)(cm)
+ }
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
+ * etc).
+ */
+ def hadoopRDD[K, V, F <: InputFormat[K, V]](
+ conf: JobConf,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
+ }
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
+ * etc).
+ */
+ def hadoopRDD[K, V, F <: InputFormat[K, V]](
+ conf: JobConf,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V]
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
+ }
+
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](
+ path: String,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
+ }
+
+ /** Get an RDD for a Hadoop file with an arbitrary InputFormat */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](
+ path: String,
+ inputFormatClass: Class[F],
+ keyClass: Class[K],
+ valueClass: Class[V]
+ ): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(keyClass)
+ implicit val vcm = ClassManifest.fromClass(valueClass)
+ new JavaPairRDD(sc.hadoopFile(path,
+ inputFormatClass, keyClass, valueClass))
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
+ path: String,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V],
+ conf: Configuration): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(kClass)
+ implicit val vcm = ClassManifest.fromClass(vClass)
+ new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
+ }
+
+ /**
+ * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
+ * and extra configuration options to pass to the input format.
+ */
+ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
+ conf: Configuration,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V]): JavaPairRDD[K, V] = {
+ implicit val kcm = ClassManifest.fromClass(kClass)
+ implicit val vcm = ClassManifest.fromClass(vClass)
+ new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
+ }
+
+ /** Build the union of two or more RDDs. */
+ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
+ val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
+ implicit val cm: ClassManifest[T] = first.classManifest
+ sc.union(rdds)(cm)
+ }
+
+ /** Build the union of two or more RDDs. */
+ override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
+ : JavaPairRDD[K, V] = {
+ val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
+ implicit val cm: ClassManifest[(K, V)] = first.classManifest
+ implicit val kcm: ClassManifest[K] = first.kManifest
+ implicit val vcm: ClassManifest[V] = first.vManifest
+ new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
+ }
+
+ /** Build the union of two or more RDDs. */
+ override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
+ val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
+ new JavaDoubleRDD(sc.union(rdds))
+ }
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+ sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ doubleAccumulator(initialValue)
+
+ /**
+ * Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values
+ * to using the `add` method. Only the master can access the accumulator's `value`.
+ */
+ def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
+ sc.accumulator(initialValue)(accumulatorParam)
+
+ /**
+ * Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks can
+ * "add" values with `add`. Only the master can access the accumuable's `value`.
+ */
+ def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+ sc.accumulable(initialValue)(param)
+
+ /**
+ * Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for
+ * reading it in distributed functions. The variable will be sent to each cluster only once.
+ */
+ def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
+
+ /** Shut down the SparkContext. */
+ def stop() {
+ sc.stop()
+ }
+
+ /**
+ * Get Spark's home location from either a value set through the constructor,
+ * or the spark.home Java property, or the SPARK_HOME environment variable
+ * (in that order of preference). If none of these is set, returns an absent Optional.
+ */
+ def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome())
+
+ /**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(path)` to find its download location.
+ */
+ def addFile(path: String) {
+ sc.addFile(path)
+ }
+
+ /**
+ * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI.
+ */
+ def addJar(path: String) {
+ sc.addJar(path)
+ }
+
+ /**
+ * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearJars() {
+ sc.clearJars()
+ }
+
+ /**
+ * Clear the job's list of files added by `addFile` so that they do not get downloaded to
+ * any new nodes.
+ */
+ def clearFiles() {
+ sc.clearFiles()
+ }
+
+ /**
+ * Returns the Hadoop configuration that Spark reuses for Hadoop code (e.g. file systems).
+ */
+ def hadoopConfiguration(): Configuration = {
+ sc.hadoopConfiguration
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists and useExisting is set to true, then the
+ * existing directory will be used. Otherwise an exception will be thrown to
+ * prevent accidental overriding of checkpoint files in the existing directory.
+ */
+ def setCheckpointDir(dir: String, useExisting: Boolean) {
+ sc.setCheckpointDir(dir, useExisting)
+ }
+
+ /**
+ * Set the directory under which RDDs are going to be checkpointed. The directory must
+ * be an HDFS path if running on a cluster. If the directory does not exist, it will
+ * be created. If the directory exists, an exception will be thrown to prevent accidental
+ * overriding of checkpoint files.
+ */
+ def setCheckpointDir(dir: String) {
+ sc.setCheckpointDir(dir)
+ }
+
+ protected def checkpointFile[T](path: String): JavaRDD[T] = {
+ implicit val cm: ClassManifest[T] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ new JavaRDD(sc.checkpointFile(path))
+ }
+}
+
+object JavaSparkContext {
+ implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc)
+
+ implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc
+}
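A sketch of driving JavaSparkContext from Java: broadcast a small lookup table and count unmatched lines with an accumulator. The input path and table contents are illustrative assumptions; imports needed beyond the earlier sketches are java.util.HashMap, org.apache.spark.Accumulator, org.apache.spark.broadcast.Broadcast and org.apache.spark.api.java.function.VoidFunction.

    JavaSparkContext sc = new JavaSparkContext("local[2]", "ContextSketch");

    Map<String, Integer> lookup = new HashMap<String, Integer>();
    lookup.put("known", 1);

    // Sent to each node once, rather than shipped with every task.
    final Broadcast<Map<String, Integer>> table = sc.broadcast(lookup);
    // Tasks add to the accumulator; only the driver reads its value.
    final Accumulator<Integer> unmatched = sc.intAccumulator(0);

    JavaRDD<String> lines = sc.textFile("input.txt");  // path is illustrative
    lines.foreach(new VoidFunction<String>() {
      public void call(String line) {
        if (!table.value().containsKey(line)) unmatched.add(1);
      }
    });
    System.out.println("unmatched lines: " + unmatched.value());
    sc.stop();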
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
new file mode 100644
index 0000000000..c9cbce5624
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java;
+
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
+
+// See
+// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html
+abstract class JavaSparkContextVarargsWorkaround {
+ public <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ public JavaDoubleRDD union(JavaDoubleRDD... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaDoubleRDD> rest = new ArrayList<JavaDoubleRDD>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
+ if (rdds.length == 0) {
+ throw new IllegalArgumentException("Union called on empty list");
+ }
+ ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1);
+ for (int i = 1; i < rdds.length; i++) {
+ rest.add(rdds[i]);
+ }
+ return union(rdds[0], rest);
+ }
+
+ // These methods take separate "first" and "rest" elements to avoid overloads with the same type erasure
+ abstract public <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest);
+ abstract public JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest);
+ abstract public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest);
+}
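From the Java side, this workaround is what makes the natural varargs call compile; a sketch with illustrative RDDs, reusing `sc` from earlier:

    JavaRDD<Integer> rddA = sc.parallelize(Arrays.asList(1, 2));
    JavaRDD<Integer> rddB = sc.parallelize(Arrays.asList(3));
    JavaRDD<Integer> rddC = sc.parallelize(Arrays.asList(4, 5));
    // Resolves to union(JavaRDD<T>...), which splits the arguments into (first, rest)
    // and delegates to the abstract overload implemented by JavaSparkContext.
    JavaRDD<Integer> merged = sc.union(rddA, rddB, rddC);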
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
new file mode 100644
index 0000000000..ecbf18849a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import com.google.common.base.Optional
+
+object JavaUtils {
+ def optionToOptional[T](option: Option[T]): Optional[T] =
+ option match {
+ case Some(value) => Optional.of(value)
+ case None => Optional.absent()
+ }
+}
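Callers see the effect of this bridge wherever the Java API returns Guava's Optional in place of scala.Option, for example (sketch, reusing `sc`):

    Optional<String> home = sc.getSparkHome();   // com.google.common.base.Optional
    if (home.isPresent()) {
      System.out.println("Spark home: " + home.get());
    }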
diff --git a/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java b/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java
new file mode 100644
index 0000000000..0744269773
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/StorageLevels.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java;
+
+import org.apache.spark.storage.StorageLevel;
+
+/**
+ * Expose some commonly useful storage level constants.
+ */
+public class StorageLevels {
+ public static final StorageLevel NONE = new StorageLevel(false, false, false, 1);
+ public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1);
+ public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2);
+ public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1);
+ public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2);
+ public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1);
+ public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2);
+ public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1);
+ public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
+ public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
+ public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
+
+ /**
+ * Create a new StorageLevel object.
+ * @param useDisk saved to disk, if true
+ * @param useMemory saved to memory, if true
+ * @param deserialized saved as deserialized objects, if true
+ * @param replication replication factor
+ */
+ public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
+ return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
+ }
+}
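
A brief Scala usage sketch of the create() factory above (illustrative only); the flags below reproduce MEMORY_AND_DISK_SER_2:

  import org.apache.spark.api.java.StorageLevels
  import org.apache.spark.storage.StorageLevel

  // Disk and memory, stored serialized, replicated on two nodes.
  val level: StorageLevel = StorageLevels.create(true, true, false, 2)
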
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
new file mode 100644
index 0000000000..4830067f7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns zero or more records of type Double from each input record.
+ */
+// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
+// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
+public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+ implements Serializable {
+
+ public abstract Iterable<Double> call(T t);
+
+ @Override
+ public final Iterable<Double> apply(T t) { return call(t); }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
new file mode 100644
index 0000000000..db34cd190a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns Doubles, and can be used to construct DoubleRDDs.
+ */
+// DoubleFunction does not extend Function because some UDF functions, like map,
+// are overloaded for both Function and DoubleFunction.
+public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
+ implements Serializable {
+
+ public abstract Double call(T t) throws Exception;
+}
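
An illustrative Scala sketch of subclassing DoubleFunction (not part of the diff):

  import org.apache.spark.api.java.function.DoubleFunction

  // Maps a string to its length; functions like this let the Java API build DoubleRDDs.
  val toLength = new DoubleFunction[String] {
    override def call(s: String): java.lang.Double = s.length.toDouble
  }
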
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
new file mode 100644
index 0000000000..158539a846
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+/**
+ * A function that returns zero or more output records from each input record.
+ */
+abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
+ @throws(classOf[Exception])
+ def call(x: T) : java.lang.Iterable[R]
+
+ def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
+}
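
An illustrative Scala sketch of a FlatMapFunction (not part of the diff):

  import java.util.Arrays
  import org.apache.spark.api.java.function.FlatMapFunction

  // Splits a line into words; the Java API expects a java.lang.Iterable back.
  val splitWords = new FlatMapFunction[String, String] {
    override def call(line: String): java.lang.Iterable[String] =
      Arrays.asList(line.split(" "): _*)
  }
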
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
new file mode 100644
index 0000000000..5ef6a814f5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+/**
+ * A function that takes two inputs and returns zero or more output records.
+ */
+abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
+ @throws(classOf[Exception])
+ def call(a: A, b:B) : java.lang.Iterable[C]
+
+ def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
new file mode 100644
index 0000000000..b9070cfd83
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+
+/**
+ * Base class for functions whose return types do not create special RDDs. PairFunction and
+ * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
+ * when mapping RDDs of other types.
+ */
+public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
+ public abstract R call(T t) throws Exception;
+
+ public ClassManifest<R> returnType() {
+ return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
new file mode 100644
index 0000000000..d4c9154869
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction2;
+
+import java.io.Serializable;
+
+/**
+ * A two-argument function that takes arguments of type T1 and T2 and returns an R.
+ */
+public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
+ implements Serializable {
+
+ public abstract R call(T1 t1, T2 t2) throws Exception;
+
+ public ClassManifest<R> returnType() {
+ return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
+
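
An illustrative Scala sketch of a Function2, the shape of function used for reductions in the Java API (not part of the diff; intValue is used to keep the boxing explicit):

  import org.apache.spark.api.java.function.Function2

  // Adds two boxed integers.
  val add = new Function2[java.lang.Integer, java.lang.Integer, java.lang.Integer] {
    override def call(a: java.lang.Integer, b: java.lang.Integer): java.lang.Integer =
      a.intValue + b.intValue
  }
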
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
new file mode 100644
index 0000000000..c0e5544b7d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.Tuple2;
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns zero or more key-value pair records from each input record. The
+ * key-value pairs are represented as scala.Tuple2 objects.
+ */
+// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
+// overloaded for both FlatMapFunction and PairFlatMapFunction.
+public abstract class PairFlatMapFunction<T, K, V>
+ extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
+ implements Serializable {
+
+ public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
+
+ public ClassManifest<K> keyType() {
+ return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public ClassManifest<V> valueType() {
+ return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
new file mode 100644
index 0000000000..40480fe8e8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.Tuple2;
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction1;
+
+import java.io.Serializable;
+
+/**
+ * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
+ */
+// PairFunction does not extend Function because some UDF functions, like map,
+// are overloaded for both Function and PairFunction.
+public abstract class PairFunction<T, K, V>
+ extends WrappedFunction1<T, Tuple2<K, V>>
+ implements Serializable {
+
+ public abstract Tuple2<K, V> call(T t) throws Exception;
+
+ public ClassManifest<K> keyType() {
+ return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+
+ public ClassManifest<V> valueType() {
+ return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
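
An illustrative Scala sketch of a PairFunction (not part of the diff):

  import org.apache.spark.api.java.function.PairFunction

  // Maps a word to a (word, 1) pair; functions like this let the Java API build PairRDDs.
  val toPair = new PairFunction[String, String, Int] {
    override def call(s: String): Tuple2[String, Int] = (s, 1)
  }
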
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
new file mode 100644
index 0000000000..ea94313a4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/VoidFunction.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+/**
+ * A function with no return value.
+ */
+// This allows Java users to write void methods without having to return Unit.
+abstract class VoidFunction[T] extends Serializable {
+ @throws(classOf[Exception])
+ def call(t: T) : Unit
+}
+
+// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly
+// return Unit), so it is implicitly converted to a Function1[T, Unit]:
+object VoidFunction {
+ implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x))
+}
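
An illustrative Scala sketch showing the implicit conversion provided by the companion object (not part of the diff):

  import org.apache.spark.api.java.function.VoidFunction

  // A side-effecting function with no result, e.g. for foreach in the Java API.
  val printer = new VoidFunction[String] {
    override def call(s: String): Unit = println(s)
  }

  // Explicitly applying the companion's conversion yields a plain Scala Function1[T, Unit]:
  val asScala: String => Unit = VoidFunction.toFunction(printer)
  asScala("hello")
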
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
new file mode 100644
index 0000000000..cfe694f65d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction1.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction1
+
+/**
+ * Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
+ @throws(classOf[Exception])
+ def call(t: T): R
+
+ final def apply(t: T): R = call(t)
+}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
new file mode 100644
index 0000000000..eb9277c6fb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction2.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction2
+
+/**
+ * Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
+ @throws(classOf[Exception])
+ def call(t1: T1, t2: T2): R
+
+ final def apply(t1: T1, t2: T2): R = call(t1, t2)
+}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
new file mode 100644
index 0000000000..eea63d5a4e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.apache.spark.Partitioner
+import org.apache.spark.Utils
+import java.util.Arrays
+
+/**
+ * A [[org.apache.spark.Partitioner]] that handles byte-array keys, for use by the Python API.
+ *
+ * Stores the unique id() of the Python-side partitioning function so that it is incorporated into
+ * equality comparisons. Correctness requires that the id is a unique identifier for the
+ * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning
+ * function). This can be ensured by using the Python id() function and maintaining a reference
+ * to the Python partitioning function so that its id() is not reused.
+ */
+private[spark] class PythonPartitioner(
+ override val numPartitions: Int,
+ val pyPartitionFunctionId: Long)
+ extends Partitioner {
+
+ override def getPartition(key: Any): Int = key match {
+ case null => 0
+ case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
+ case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: PythonPartitioner =>
+ h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
+ case _ =>
+ false
+ }
+}
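
Utils.nonNegativeMod is not shown in this diff; the sketch below is an assumed implementation, included only to clarify how getPartition maps hash codes of either sign into [0, numPartitions):

  // Assumed behaviour: a modulo that is never negative.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  // e.g. nonNegativeMod(-7, 3) == 2, so keys with negative hash codes still land in a valid partition.
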
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
new file mode 100644
index 0000000000..621f0fe8ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io._
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark._
+import org.apache.spark.rdd.PipedRDD
+
+
+private[spark] class PythonRDD[T: ClassManifest](
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
+ extends RDD[Array[Byte]](parent) {
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ // Similar to Runtime.exec(): if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ def this(parent: RDD[T], command: String, envVars: JMap[String, String],
+ pythonIncludes: JList[String],
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
+ this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
+ broadcastVars, accumulator)
+
+ override def getPartitions = parent.partitions
+
+ override val partitioner = if (preservePartitoning) parent.partitioner else None
+
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ val startTime = System.currentTimeMillis
+ val env = SparkEnv.get
+ val worker = env.createPythonWorker(pythonExec, envVars.toMap)
+
+ // Start a thread to feed the process input from our parent's iterator
+ new Thread("stdin writer for " + pythonExec) {
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ val printOut = new PrintWriter(stream)
+ // Partition index
+ dataOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ // Broadcast variables
+ dataOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ }
+ // Python includes (*.zip and *.egg files)
+ dataOut.writeInt(pythonIncludes.length)
+ for (f <- pythonIncludes) {
+ PythonRDD.writeAsPickle(f, dataOut)
+ }
+ dataOut.flush()
+ // Serialized user code
+ for (elem <- command) {
+ printOut.println(elem)
+ }
+ printOut.flush()
+ // Data values
+ for (elem <- parent.iterator(split, context)) {
+ PythonRDD.writeAsPickle(elem, dataOut)
+ }
+ dataOut.flush()
+ printOut.flush()
+ worker.shutdownOutput()
+ } catch {
+ case e: IOException =>
+ // This can happen for legitimate reasons if the Python code stops returning data before we are done
+ // passing elements through, e.g., for take(). Just log a message to say it happened.
+ logInfo("stdin writer to Python finished early")
+ logDebug("stdin writer to Python finished early", e)
+ }
+ }
+ }.start()
+
+ // Return an iterator that reads output records from the process's stdout
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+ return new Iterator[Array[Byte]] {
+ def next(): Array[Byte] = {
+ val obj = _nextObj
+ if (hasNext) {
+ // FIXME: can deadlock if worker is waiting for us to
+ // respond to current message (currently irrelevant because
+ // output is shutdown before we read any input)
+ _nextObj = read()
+ }
+ obj
+ }
+
+ private def read(): Array[Byte] = {
+ try {
+ stream.readInt() match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ case -3 =>
+ // Timing data from worker
+ val bootTime = stream.readLong()
+ val initTime = stream.readLong()
+ val finishTime = stream.readLong()
+ val boot = bootTime - startTime
+ val init = initTime - bootTime
+ val finish = finishTime - initTime
+ val total = finishTime - startTime
+ logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+ read
+ case -2 =>
+ // Signals that an exception has been thrown in python
+ val exLength = stream.readInt()
+ val obj = new Array[Byte](exLength)
+ stream.readFully(obj)
+ throw new PythonException(new String(obj))
+ case -1 =>
+ // We've finished the data section of the output, but we can still
+ // read some accumulator updates; let's do that, breaking when we
+ // get a negative length record.
+ var len2 = stream.readInt()
+ while (len2 >= 0) {
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ len2 = stream.readInt()
+ }
+ new Array[Byte](0)
+ }
+ } catch {
+ case eof: EOFException => {
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+ }
+ case e => throw e
+ }
+ }
+
+ var _nextObj = read()
+
+ def hasNext = _nextObj.length != 0
+ }
+ }
+
+ val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/** Thrown for exceptions in user Python code. */
+private class PythonException(msg: String) extends Exception(msg)
+
+/**
+ * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
+ * This is used by PySpark's shuffle operations.
+ */
+private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
+ RDD[(Array[Byte], Array[Byte])](prev) {
+ override def getPartitions = prev.partitions
+ override def compute(split: Partition, context: TaskContext) =
+ prev.iterator(split, context).grouped(2).map {
+ case Seq(a, b) => (a, b)
+ case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
+ }
+ val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+private[spark] object PythonRDD {
+
+ /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
+ def stripPickle(arr: Array[Byte]) : Array[Byte] = {
+ arr.slice(2, arr.length - 1)
+ }
+
+ /**
+ * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
+ * The data format is a 32-bit integer representing the pickled object's length (in bytes),
+ * followed by the pickled data.
+ *
+ * Pickle module:
+ *
+ * http://docs.python.org/2/library/pickle.html
+ *
+ * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
+ *
+ * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
+ * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
+ *
+ * @param elem the object to write
+ * @param dOut a data output stream
+ */
+ def writeAsPickle(elem: Any, dOut: DataOutputStream) {
+ if (elem.isInstanceOf[Array[Byte]]) {
+ val arr = elem.asInstanceOf[Array[Byte]]
+ dOut.writeInt(arr.length)
+ dOut.write(arr)
+ } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
+ val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
+ val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes from each pickle; PROTO, TWO, TUPLE2 and STOP add 4 back
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(PythonRDD.stripPickle(t._1))
+ dOut.write(PythonRDD.stripPickle(t._2))
+ dOut.writeByte(Pickle.TUPLE2)
+ dOut.writeByte(Pickle.STOP)
+ } else if (elem.isInstanceOf[String]) {
+ // For uniformity, strings are wrapped into Pickles.
+ val s = elem.asInstanceOf[String].getBytes("UTF-8")
+ val length = 2 + 1 + 4 + s.length + 1
+ dOut.writeInt(length)
+ dOut.writeByte(Pickle.PROTO)
+ dOut.writeByte(Pickle.TWO)
+ dOut.write(Pickle.BINUNICODE)
+ dOut.writeInt(Integer.reverseBytes(s.length))
+ dOut.write(s)
+ dOut.writeByte(Pickle.STOP)
+ } else {
+ throw new SparkException("Unexpected RDD type")
+ }
+ }
+
+ def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ JavaRDD[Array[Byte]] = {
+ val file = new DataInputStream(new FileInputStream(filename))
+ val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+ try {
+ while (true) {
+ val length = file.readInt()
+ val obj = new Array[Byte](length)
+ file.readFully(obj)
+ objs.append(obj)
+ }
+ } catch {
+ case eof: EOFException => {}
+ case e => throw e
+ }
+ JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+ }
+
+ def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ import scala.collection.JavaConverters._
+ writeIteratorToPickleFile(items.asScala, filename)
+ }
+
+ def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ val file = new DataOutputStream(new FileOutputStream(filename))
+ for (item <- items) {
+ writeAsPickle(item, file)
+ }
+ file.close()
+ }
+
+ def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
+ implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+ rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
+ }
+}
+
+private object Pickle {
+ val PROTO: Byte = 0x80.toByte
+ val TWO: Byte = 0x02.toByte
+ val BINUNICODE: Byte = 'X'
+ val STOP: Byte = '.'
+ val TUPLE2: Byte = 0x86.toByte
+ val EMPTY_LIST: Byte = ']'
+ val MARK: Byte = '('
+ val APPENDS: Byte = 'e'
+}
+
+private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
+ override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+ extends AccumulatorParam[JList[Array[Byte]]] {
+
+ Utils.checkHost(serverHost, "Expected hostname")
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+ override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+ : JList[Array[Byte]] = {
+ if (serverHost == null) {
+ // This happens on the worker node, where we just want to remember all the updates
+ val1.addAll(val2)
+ val1
+ } else {
+ // This happens on the master, where we pass the updates to Python through a socket
+ val socket = new Socket(serverHost, serverPort)
+ val in = socket.getInputStream
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
+ out.writeInt(val2.size)
+ for (array <- val2) {
+ out.writeInt(array.length)
+ out.write(array)
+ }
+ out.flush()
+ // Wait for a byte from the Python side as an acknowledgement
+ val byteRead = in.read()
+ if (byteRead == -1) {
+ throw new SparkException("EOF reached before Python server acknowledged")
+ }
+ socket.close()
+ null
+ }
+ }
+}
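
An illustrative Scala sketch of round-tripping byte records through the pickle-file helpers above. PythonRDD is private[spark], so the sketch lives under that package; the object name and the /tmp path are made up for illustration:

  package org.apache.spark.api.python

  import java.util.Arrays

  object PickleFileSketch {
    def main(args: Array[String]) {
      // Each record is framed as a 32-bit length followed by its bytes (see writeAsPickle).
      val records = Arrays.asList(Array[Byte](1, 2, 3), Array[Byte](4, 5))
      PythonRDD.writeIteratorToPickleFile(records.iterator(), "/tmp/records.pickle")
      // Reading back requires a JavaSparkContext (called jsc here, not constructed in this sketch):
      // val rdd = PythonRDD.readRDDFromPickleFile(jsc, "/tmp/records.pickle", 2)
    }
  }
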
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
new file mode 100644
index 0000000000..08e3f670f5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.{File, DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+ extends Logging {
+ var daemon: Process = null
+ val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+ var daemonPort: Int = 0
+
+ def create(): Socket = {
+ synchronized {
+ // Start the daemon if it hasn't been started
+ startDaemon()
+
+ // Attempt to connect; if that fails, restart the daemon and retry once
+ try {
+ new Socket(daemonHost, daemonPort)
+ } catch {
+ case exc: SocketException => {
+ logWarning("Python daemon unexpectedly quit, attempting to restart")
+ stopDaemon()
+ startDaemon()
+ new Socket(daemonHost, daemonPort)
+ }
+ case e => throw e
+ }
+ }
+ }
+
+ def stop() {
+ stopDaemon()
+ }
+
+ private def startDaemon() {
+ synchronized {
+ // Is it already running?
+ if (daemon != null) {
+ return
+ }
+
+ try {
+ // Create and start the daemon
+ val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+ val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+ val workerEnv = pb.environment()
+ workerEnv.putAll(envVars)
+ val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
+ workerEnv.put("PYTHONPATH", pythonPath)
+ daemon = pb.start()
+
+ // Redirect the stderr to ours
+ new Thread("stderr reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ val in = daemon.getErrorStream
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+
+ val in = new DataInputStream(daemon.getInputStream)
+ daemonPort = in.readInt()
+
+ // Redirect further stdout output to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
+ } catch {
+ case e => {
+ stopDaemon()
+ throw e
+ }
+ }
+
+ // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+ // detect our disappearance.
+ }
+ }
+
+ private def stopDaemon() {
+ synchronized {
+ // Request shutdown of existing daemon by sending SIGTERM
+ if (daemon != null) {
+ daemon.destroy()
+ }
+
+ daemon = null
+ daemonPort = 0
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
new file mode 100644
index 0000000000..99e86237fc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
@@ -0,0 +1,1057 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ListBuffer, Map, Set}
+import scala.math
+
+import org.apache.spark._
+import org.apache.spark.storage.StorageLevel
+
+private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+ extends Broadcast[T](id)
+ with Logging
+ with Serializable {
+
+ def value = value_
+
+ def blockId: String = "broadcast_" + id
+
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var hasBlocksBitVector: BitSet = null
+ @transient var numCopiesSent: Array[Int] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = new AtomicInteger(0)
+
+ // Used ONLY by driver to track how many unique blocks have been sent out
+ @transient var sentBlocks = new AtomicInteger(0)
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+
+ @transient var listOfSources = ListBuffer[SourceInfo]()
+
+ @transient var serveMR: ServeMultipleRequests = null
+
+ // Used only in driver
+ @transient var guideMR: GuideMultipleRequests = null
+
+ // Used only in Workers
+ @transient var ttGuide: TalkToGuide = null
+
+ @transient var hostAddress = Utils.localIpAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ logInfo("Local host address: " + hostAddress)
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = MultiTracker.blockifyObject(value_)
+
+ // Prepare the value being broadcasted
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks.set(variableInfo.totalBlocks)
+
+ // Guide has all the blocks
+ hasBlocksBitVector = new BitSet(totalBlocks)
+ hasBlocksBitVector.set(0, totalBlocks)
+
+ // Guide still hasn't sent any block
+ numCopiesSent = new Array[Int](totalBlocks)
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon(true)
+ guideMR.start()
+ logInfo("GuideMultipleRequests started...")
+
+ // Must always come AFTER guideMR is created
+ while (guidePort == -1) {
+ guidePortLock.synchronized { guidePortLock.wait() }
+ }
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ // Must always come AFTER serveMR is created
+ while (listenPort == -1) {
+ listenPortLock.synchronized { listenPortLock.wait() }
+ }
+
+ // Must always come AFTER listenPort is created
+ val driverSource =
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
+ hasBlocksBitVector.synchronized {
+ driverSource.hasBlocksBitVector = hasBlocksBitVector
+ }
+
+ // In the beginning, this is the only known source to Guide
+ listOfSources += driverSource
+
+ // Register with the Tracker
+ MultiTracker.registerBroadcast(id,
+ SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
+ }
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.getSingle(blockId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ logInfo("Started reading broadcast variable " + id)
+ // Initializing everything because driver will only send null/0 values
+ // Only the 1st worker in a node can be here. Others will get from cache
+ initializeWorkerVariables()
+
+ logInfo("Local host address: " + hostAddress)
+
+ // Start local ServeMultipleRequests thread first
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(id)
+ if (receptionSucceeded) {
+ value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ // Initialize variables in the worker node. Driver sends everything as 0/null
+ private def initializeWorkerVariables() {
+ arrayOfBlocks = null
+ hasBlocksBitVector = null
+ numCopiesSent = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = new AtomicInteger(0)
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+
+ serveMR = null
+ ttGuide = null
+
+ hostAddress = Utils.localIpAddress
+ listenPort = -1
+
+ listOfSources = ListBuffer[SourceInfo]()
+
+ stopBroadcast = false
+ }
+
+ private def getLocalSourceInfo: SourceInfo = {
+ // Wait till hostAddress and listenPort are OK
+ while (listenPort == -1) {
+ listenPortLock.synchronized { listenPortLock.wait() }
+ }
+
+ // Wait till totalBlocks and totalBytes are OK
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized { totalBlocksLock.wait() }
+ }
+
+ var localSourceInfo = SourceInfo(
+ hostAddress, listenPort, totalBlocks, totalBytes)
+
+ localSourceInfo.hasBlocks = hasBlocks.get
+
+ hasBlocksBitVector.synchronized {
+ localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
+ }
+
+ return localSourceInfo
+ }
+
+ // Add new SourceInfo to the listOfSources. Update if it exists already.
+ // Optimizing just by OR-ing the BitVectors was BAD for performance
+ private def addToListOfSources(newSourceInfo: SourceInfo) {
+ listOfSources.synchronized {
+ if (listOfSources.contains(newSourceInfo)) {
+ listOfSources = listOfSources - newSourceInfo
+ }
+ listOfSources += newSourceInfo
+ }
+ }
+
+ private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
+ newSourceInfos.foreach { newSourceInfo =>
+ addToListOfSources(newSourceInfo)
+ }
+ }
+
+ class TalkToGuide(gInfo: SourceInfo)
+ extends Thread with Logging {
+ override def run() {
+
+ // Keep exchanging information until all blocks have been received
+ while (hasBlocks.get < totalBlocks) {
+ talkOnce
+ Thread.sleep(MultiTracker.ranGen.nextInt(
+ MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
+ MultiTracker.MinKnockInterval)
+ }
+
+ // Talk one more time to let the Guide know of reception completion
+ talkOnce
+ }
+
+ // Connect to Guide and send this worker's information
+ private def talkOnce {
+ var clientSocketToGuide: Socket = null
+ var oosGuide: ObjectOutputStream = null
+ var oisGuide: ObjectInputStream = null
+
+ clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
+ oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
+ oosGuide.flush()
+ oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
+
+ // Send local information
+ oosGuide.writeObject(getLocalSourceInfo)
+ oosGuide.flush()
+
+ // Receive source information from Guide
+ var suitableSources =
+ oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
+ logDebug("Received suitableSources from Driver " + suitableSources)
+
+ addToListOfSources(suitableSources)
+
+ oisGuide.close()
+ oosGuide.close()
+ clientSocketToGuide.close()
+ }
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ val gInfo = MultiTracker.getGuideInfo(variableID)
+
+ if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized { listenPortLock.wait() }
+ }
+
+ // Setup initial states of variables
+ totalBlocks = gInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
+ hasBlocksBitVector = new BitSet(totalBlocks)
+ numCopiesSent = new Array[Int](totalBlocks)
+ totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
+ totalBytes = gInfo.totalBytes
+
+ // Start ttGuide to periodically talk to the Guide
+ var ttGuide = new TalkToGuide(gInfo)
+ ttGuide.setDaemon(true)
+ ttGuide.start()
+ logInfo("TalkToGuide started...")
+
+ // Start pController to run TalkToPeer threads
+ var pcController = new PeerChatterController
+ pcController.setDaemon(true)
+ pcController.start()
+ logInfo("PeerChatterController started...")
+
+ // FIXME: Must fix this. This might never break if broadcast fails.
+ // We should be able to break and send false. Also need to kill threads
+ while (hasBlocks.get < totalBlocks) {
+ Thread.sleep(MultiTracker.MaxKnockInterval)
+ }
+
+ return true
+ }
+
+ class PeerChatterController
+ extends Thread with Logging {
+ private var peersNowTalking = ListBuffer[SourceInfo]()
+ // TODO: There is a possible bug with blocksInRequestBitVector when a
+ // certain bit is NOT unset upon failure resulting in an infinite loop.
+ private var blocksInRequestBitVector = new BitSet(totalBlocks)
+
+ override def run() {
+ var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
+
+ while (hasBlocks.get < totalBlocks) {
+ var numThreadsToCreate = 0
+ listOfSources.synchronized {
+ numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
+ threadPool.getActiveCount
+ }
+
+ while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
+ var peerToTalkTo = pickPeerToTalkToRandom
+
+ if (peerToTalkTo != null)
+ logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
+ else
+ logDebug("No peer chosen...")
+
+ if (peerToTalkTo != null) {
+ threadPool.execute(new TalkToPeer(peerToTalkTo))
+
+ // Add to peersNowTalking. Remove in the thread. We have to do this
+ // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
+ peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
+ }
+
+ numThreadsToCreate = numThreadsToCreate - 1
+ }
+
+ // Sleep for a while before starting some more threads
+ Thread.sleep(MultiTracker.MinKnockInterval)
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ // Right now we pick the peer that has the most blocks this receiver wants,
+ // falling back to a random peer if no one has anything interesting
+ private def pickPeerToTalkToRandom: SourceInfo = {
+ var curPeer: SourceInfo = null
+ var curMax = 0
+
+ logDebug("Picking peers to talk to...")
+
+ // Find peers that are not connected right now
+ var peersNotInUse = ListBuffer[SourceInfo]()
+ listOfSources.synchronized {
+ peersNowTalking.synchronized {
+ peersNotInUse = listOfSources -- peersNowTalking
+ }
+ }
+
+ // Select the peer that has the most blocks that this receiver does not
+ peersNotInUse.foreach { eachSource =>
+ var tempHasBlocksBitVector: BitSet = null
+ hasBlocksBitVector.synchronized {
+ tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+ tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
+ tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
+
+ if (tempHasBlocksBitVector.cardinality > curMax) {
+ curPeer = eachSource
+ curMax = tempHasBlocksBitVector.cardinality
+ }
+ }
+
+ // Always picking randomly
+ if (curPeer == null && peersNotInUse.size > 0) {
+ // Pick uniformly the i'th required peer
+ var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
+
+ var peerIter = peersNotInUse.iterator
+ curPeer = peerIter.next
+
+ while (i > 0) {
+ curPeer = peerIter.next
+ i = i - 1
+ }
+ }
+
+ return curPeer
+ }
+
+ // Picking peer with the weight of rare blocks it has
+ private def pickPeerToTalkToRarestFirst: SourceInfo = {
+ // Find peers that are not connected right now
+ var peersNotInUse = ListBuffer[SourceInfo]()
+ listOfSources.synchronized {
+ peersNowTalking.synchronized {
+ peersNotInUse = listOfSources -- peersNowTalking
+ }
+ }
+
+ // Count the number of copies of each block in the neighborhood
+ var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
+
+ listOfSources.synchronized {
+ listOfSources.foreach { eachSource =>
+ for (i <- 0 until totalBlocks) {
+ numCopiesPerBlock(i) +=
+ ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
+ }
+ }
+ }
+
+ // A block is considered rare if there are at most 2 copies of that block
+ // This CONSTANT could be a function of the neighborhood size
+ var rareBlocksIndices = ListBuffer[Int]()
+ for (i <- 0 until totalBlocks) {
+ if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
+ rareBlocksIndices += i
+ }
+ }
+
+ // Find peers with rare blocks
+ var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
+ var totalRareBlocks = 0
+
+ peersNotInUse.foreach { eachPeer =>
+ var hasRareBlocks = 0
+ rareBlocksIndices.foreach { rareBlock =>
+ if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
+ hasRareBlocks += 1
+ }
+ }
+
+ if (hasRareBlocks > 0) {
+ peersWithRareBlocks += ((eachPeer, hasRareBlocks))
+ }
+ totalRareBlocks += hasRareBlocks
+ }
+
+ // Select a peer from peersWithRareBlocks based on weight calculated from
+ // unique rare blocks
+ var selectedPeerToTalkTo: SourceInfo = null
+
+ if (peersWithRareBlocks.size > 0) {
+ // Sort the peers based on how many rare blocks they have
+ peersWithRareBlocks = peersWithRareBlocks.sortBy(_._2)
+
+ var randomNumber = MultiTracker.ranGen.nextDouble
+ var tempSum = 0.0
+
+ var i = 0
+ do {
+ tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
+ if (tempSum >= randomNumber) {
+ selectedPeerToTalkTo = peersWithRareBlocks(i)._1
+ }
+ i += 1
+ } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
+ }
+
+ if (selectedPeerToTalkTo == null) {
+ selectedPeerToTalkTo = pickPeerToTalkToRandom
+ }
+
+ return selectedPeerToTalkTo
+ }
+
+ class TalkToPeer(peerToTalkTo: SourceInfo)
+ extends Thread with Logging {
+ private var peerSocketToSource: Socket = null
+ private var oosSource: ObjectOutputStream = null
+ private var oisSource: ObjectInputStream = null
+
+ override def run() {
+ // TODO: There is a possible bug here regarding blocksInRequestBitVector
+ var blockToAskFor = -1
+
+ // Setup the timeout mechanism
+ var timeOutTask = new TimerTask {
+ override def run() {
+ cleanUpConnections()
+ }
+ }
+
+ var timeOutTimer = new Timer
+ timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
+
+ logInfo("TalkToPeer started... => " + peerToTalkTo)
+
+ try {
+ // Connect to the source
+ peerSocketToSource =
+ new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
+ oosSource =
+ new ObjectOutputStream(peerSocketToSource.getOutputStream)
+ oosSource.flush()
+ oisSource =
+ new ObjectInputStream(peerSocketToSource.getInputStream)
+
+ // Receive latest SourceInfo from peerToTalkTo
+ var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
+ // Update listOfSources
+ addToListOfSources(newPeerToTalkTo)
+
+ // Turn the timer OFF, if the sender responds before timeout
+ timeOutTimer.cancel()
+
+ // Send the latest SourceInfo
+ oosSource.writeObject(getLocalSourceInfo)
+ oosSource.flush()
+
+ var keepReceiving = true
+
+ while (hasBlocks.get < totalBlocks && keepReceiving) {
+ blockToAskFor =
+ pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
+
+ // No block to request
+ if (blockToAskFor < 0) {
+ // Nothing to receive from newPeerToTalkTo
+ keepReceiving = false
+ } else {
+ // Let other threads know that blockToAskFor is being requested
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor)
+ }
+
+ // Start with sending the blockID
+ oosSource.writeObject(blockToAskFor)
+ oosSource.flush()
+
+ // CHANGED: Driver might send some other block than the one
+ // requested to ensure fast spreading of all blocks.
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
+
+ if (!hasBlocksBitVector.get(bcBlock.blockID)) {
+ arrayOfBlocks(bcBlock.blockID) = bcBlock
+
+ // Update the hasBlocksBitVector first
+ hasBlocksBitVector.synchronized {
+ hasBlocksBitVector.set(bcBlock.blockID)
+ hasBlocks.getAndIncrement
+ }
+
+ // Some block (may NOT be blockToAskFor) has arrived.
+ // In any case, blockToAskFor is not in request any more
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor, false)
+ }
+
+ // Reset blockToAskFor to -1. Else it will be considered missing
+ blockToAskFor = -1
+ }
+
+ // Send the latest SourceInfo
+ oosSource.writeObject(getLocalSourceInfo)
+ oosSource.flush()
+ }
+ }
+ } catch {
+ // EOFException is expected to happen because sender can break
+ // connection due to timeout
+ case eofe: java.io.EOFException => { }
+ case e: Exception => {
+ logError("TalktoPeer had a " + e)
+ // FIXME: Remove 'newPeerToTalkTo' from listOfSources
+ // We probably should have the following in some form, but not
+ // really here. This exception can happen if the sender just breaks connection
+ // listOfSources.synchronized {
+ // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
+ // listOfSources = listOfSources - peerToTalkTo
+ // }
+ }
+ } finally {
+ // blockToAskFor != -1 => there was an exception
+ if (blockToAskFor != -1) {
+ blocksInRequestBitVector.synchronized {
+ blocksInRequestBitVector.set(blockToAskFor, false)
+ }
+ }
+
+ cleanUpConnections()
+ }
+ }
+
+    // Right now it uniformly picks a block that this peer does not have
+ private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
+ var needBlocksBitVector: BitSet = null
+
+ // Blocks already present
+ hasBlocksBitVector.synchronized {
+ needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Include blocks already in transmission ONLY IF
+ // MultiTracker.EndGameFraction has NOT been achieved
+ if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
+ blocksInRequestBitVector.synchronized {
+ needBlocksBitVector.or(blocksInRequestBitVector)
+ }
+ }
+
+ // Find blocks that are neither here nor in transit
+ needBlocksBitVector.flip(0, needBlocksBitVector.size)
+
+ // Blocks that should/can be requested
+ needBlocksBitVector.and(txHasBlocksBitVector)
+
+ if (needBlocksBitVector.cardinality == 0) {
+ return -1
+ } else {
+ // Pick uniformly the i'th required block
+ var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
+ var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
+
+ while (i > 0) {
+ pickedBlockIndex =
+ needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
+ i -= 1
+ }
+
+ return pickedBlockIndex
+ }
+ }
+
+ // Pick the block that seems to be the rarest across sources
+ private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
+ var needBlocksBitVector: BitSet = null
+
+ // Blocks already present
+ hasBlocksBitVector.synchronized {
+ needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
+ }
+
+ // Include blocks already in transmission ONLY IF
+ // MultiTracker.EndGameFraction has NOT been achieved
+ if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
+ blocksInRequestBitVector.synchronized {
+ needBlocksBitVector.or(blocksInRequestBitVector)
+ }
+ }
+
+ // Find blocks that are neither here nor in transit
+ needBlocksBitVector.flip(0, needBlocksBitVector.size)
+
+ // Blocks that should/can be requested
+ needBlocksBitVector.and(txHasBlocksBitVector)
+
+ if (needBlocksBitVector.cardinality == 0) {
+ return -1
+ } else {
+ // Count the number of copies for each block across all sources
+        var numCopiesPerBlock = Array.tabulate[Int](totalBlocks)(_ => 0)
+
+ listOfSources.synchronized {
+ listOfSources.foreach { eachSource =>
+ for (i <- 0 until totalBlocks) {
+ numCopiesPerBlock(i) +=
+ ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
+ }
+ }
+ }
+
+ // Find the minimum
+ var minVal = Integer.MAX_VALUE
+ for (i <- 0 until totalBlocks) {
+ if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
+ minVal = numCopiesPerBlock(i)
+ }
+ }
+
+ // Find the blocks with the least copies that this peer does not have
+ var minBlocksIndices = ListBuffer[Int]()
+ for (i <- 0 until totalBlocks) {
+ if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
+ minBlocksIndices += i
+ }
+ }
+
+ // Now select a random index from minBlocksIndices
+ if (minBlocksIndices.size == 0) {
+ return -1
+ } else {
+ // Pick uniformly the i'th index
+ var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
+ return minBlocksIndices(i)
+ }
+ }
+ }
+
+ private def cleanUpConnections() {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (peerSocketToSource != null) {
+ peerSocketToSource.close()
+ }
+
+ // Delete from peersNowTalking
+ peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
+ }
+ }
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo]()
+
+ override def run() {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ guidePort = serverSocket.getLocalPort
+ logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized { guidePortLock.notifyAll() }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ // Stop broadcast if at least one worker has connected and
+            // everyone connected so far is done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
+ }
+ }
+ }
+ if (clientSocket != null) {
+ logDebug("Guide: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new GuideSingleRequest(clientSocket))
+ } catch {
+ // In failure, close the socket here; else, thread will close it
+ case ioe: IOException => {
+ clientSocket.close()
+ }
+ }
+ }
+ }
+
+ // Shutdown the thread pool
+ threadPool.shutdown()
+
+ logInfo("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ MultiTracker.unregisterBroadcast(id)
+ } finally {
+ if (serverSocket != null) {
+ logInfo("GuideMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ }
+
+ private def sendStopBroadcastNotifications() {
+ listOfSources.synchronized {
+ listOfSources.foreach { sourceInfo =>
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ gosSource.flush()
+ gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
+
+ // Throw away whatever comes in
+ gisSource.readObject.asInstanceOf[SourceInfo]
+
+ // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
+ gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
+ gosSource.flush()
+ } catch {
+ case e: Exception => {
+ logError("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close()
+ }
+ if (gosSource != null) {
+ gosSource.close()
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close()
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var sourceInfo: SourceInfo = null
+ private var selectedSources: ListBuffer[SourceInfo] = null
+
+ override def run() {
+ try {
+ logInfo("new GuideSingleRequest is running")
+ // Connecting worker is sending in its information
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Select a suitable source and send it back to the worker
+ selectedSources = selectSuitableSources(sourceInfo)
+ logDebug("Sending selectedSources:" + selectedSources)
+ oos.writeObject(selectedSources)
+ oos.flush()
+
+ // Add this source to the listOfSources
+ addToListOfSources(sourceInfo)
+ } catch {
+ case e: Exception => {
+ // Assuming exception caused by receiver failure: remove
+ if (listOfSources != null) {
+ listOfSources.synchronized { listOfSources -= sourceInfo }
+ }
+ }
+ } finally {
+ logInfo("GuideSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ // Randomly select some sources to send back
+ private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
+ var selectedSources = ListBuffer[SourceInfo]()
+
+ // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
+      // then add skipSourceInfo to setOfCompletedSources and return an empty list.
+ if (skipSourceInfo.hasBlocks == totalBlocks) {
+ setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
+ return selectedSources
+ }
+
+ listOfSources.synchronized {
+ if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
+ selectedSources = listOfSources.clone
+ } else {
+ var picksLeft = MultiTracker.MaxPeersInGuideResponse
+ var alreadyPicked = new BitSet(listOfSources.size)
+
+ while (picksLeft > 0) {
+ var i = -1
+
+ do {
+ i = MultiTracker.ranGen.nextInt(listOfSources.size)
+ } while (alreadyPicked.get(i))
+
+ var peerIter = listOfSources.iterator
+ var curPeer = peerIter.next
+
+ // Set the BitSet before i is decremented
+ alreadyPicked.set(i)
+
+ while (i > 0) {
+ curPeer = peerIter.next
+ i = i - 1
+ }
+
+ selectedSources += curPeer
+
+ picksLeft = picksLeft - 1
+ }
+ }
+ }
+
+ // Remove the receiving source (if present)
+ selectedSources = selectedSources - skipSourceInfo
+
+ return selectedSources
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+    // Serve at most MultiTracker.MaxChatSlots peers
+ var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
+
+ override def run() {
+ var serverSocket = new ServerSocket(0)
+ listenPort = serverSocket.getLocalPort
+
+ logInfo("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized { listenPortLock.notifyAll() }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => { }
+ }
+ if (clientSocket != null) {
+ logDebug("Serve: Accepted new client connection:" + clientSocket)
+ try {
+ threadPool.execute(new ServeSingleRequest(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ServeMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ServeSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ logInfo("new ServeSingleRequest is running")
+
+ override def run() {
+ try {
+ // Send latest local SourceInfo to the receiver
+ // In the case of receiver timeout and connection close, this will
+ // throw a java.net.SocketException: Broken pipe
+ oos.writeObject(getLocalSourceInfo)
+ oos.flush()
+
+ // Receive latest SourceInfo from the receiver
+ var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ addToListOfSources(rxSourceInfo)
+ }
+
+ val startTime = System.currentTimeMillis
+ var curTime = startTime
+ var keepSending = true
+ var numBlocksToSend = MultiTracker.MaxChatBlocks
+
+ while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
+ // Receive which block to send
+ var blockToSend = ois.readObject.asInstanceOf[Int]
+
+ // If it is driver AND at least one copy of each block has not been
+ // sent out already, MODIFY blockToSend
+ if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
+ blockToSend = sentBlocks.getAndIncrement
+ }
+
+ // Send the block
+ sendBlock(blockToSend)
+ rxSourceInfo.hasBlocksBitVector.set(blockToSend)
+
+ numBlocksToSend -= 1
+
+ // Receive latest SourceInfo from the receiver
+ rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+ logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
+ addToListOfSources(rxSourceInfo)
+
+ curTime = System.currentTimeMillis
+          // Stop sending only if there is someone waiting in the queue
+ if (curTime - startTime >= MultiTracker.MaxChatTime &&
+ threadPool.getQueue.size > 0) {
+ keepSending = false
+ }
+ }
+ } catch {
+ case e: Exception => logError("ServeSingleRequest had a " + e)
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ private def sendBlock(blockToSend: Int) {
+ try {
+ oos.writeObject(arrayOfBlocks(blockToSend))
+ oos.flush()
+ } catch {
+ case e: Exception => logError("sendBlock had a " + e)
+ }
+ logDebug("Sent block: " + blockToSend + " to " + clientSocket)
+ }
+ }
+ }
+}
+
+private[spark] class BitTorrentBroadcastFactory
+extends BroadcastFactory {
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new BitTorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { MultiTracker.stop() }
+}
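
The block-selection walk in pickBlockRandom above boils down to jumping to the i'th set bit of a java.util.BitSet. Below is a minimal standalone sketch of just that step, assuming nothing beyond the JDK; BitPickSketch, needBits and rng are illustrative names, not part of this patch.

import java.util.BitSet
import scala.util.Random

object BitPickSketch {
  // Return a uniformly random set bit of needBits, or -1 if none are set.
  // Mirrors pickBlockRandom: draw i, then walk to the i'th set bit.
  def pickRandomSetBit(needBits: BitSet, rng: Random): Int = {
    val count = needBits.cardinality
    if (count == 0) return -1
    var i = rng.nextInt(count)
    var picked = needBits.nextSetBit(0)
    while (i > 0) {
      picked = needBits.nextSetBit(picked + 1)
      i -= 1
    }
    picked
  }

  def main(args: Array[String]) {
    val bits = new BitSet(8)
    bits.set(1); bits.set(4); bits.set(6)
    println(pickRandomSetBit(bits, new Random))  // prints 1, 4 or 6
  }
}
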
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
new file mode 100644
index 0000000000..43c18294c5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark._
+
+abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
+ def value: T
+
+ // We cannot have an abstract readObject here due to some weird issues with
+ // readObject having to be 'private' in sub-classes.
+
+ override def toString = "Broadcast(" + id + ")"
+}
+
+private[spark]
+class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
+
+ private var initialized = false
+ private var broadcastFactory: BroadcastFactory = null
+
+ initialize()
+
+ // Called by SparkContext or Executor before using Broadcast
+ private def initialize() {
+ synchronized {
+ if (!initialized) {
+ val broadcastFactoryClass = System.getProperty(
+ "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+
+ broadcastFactory =
+ Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+
+ // Initialize appropriate BroadcastFactory and BroadcastObject
+ broadcastFactory.initialize(isDriver)
+
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ broadcastFactory.stop()
+ }
+
+ private val nextBroadcastId = new AtomicLong(0)
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean) =
+ broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+
+ def isDriver = _isDriver
+}
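
BroadcastManager.initialize above resolves the factory class reflectively from the spark.broadcast.factory system property. Here is a self-contained sketch of that lookup pattern in isolation; Greeter, ConsoleGreeter and the sketch.greeter.class property are illustrative stand-ins, not Spark APIs.

// A trait plus a default implementation, selected by name at runtime.
trait Greeter { def greet(name: String): String }
class ConsoleGreeter extends Greeter { def greet(name: String) = "Hello, " + name }

object FactoryLookupSketch {
  def main(args: Array[String]) {
    // Same shape as BroadcastManager.initialize: property -> Class.forName -> newInstance -> cast.
    val className = System.getProperty("sketch.greeter.class", "ConsoleGreeter")
    val greeter = Class.forName(className).newInstance.asInstanceOf[Greeter]
    println(greeter.greet("Spark"))
  }
}
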
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
new file mode 100644
index 0000000000..68bff75b90
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+/**
+ * An interface for all the broadcast implementations in Spark (to allow
+ * multiple broadcast implementations). SparkContext uses a user-specified
+ * BroadcastFactory implementation to instantiate a particular broadcast for the
+ * entire Spark job.
+ */
+private[spark] trait BroadcastFactory {
+ def initialize(isDriver: Boolean): Unit
+ def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+ def stop(): Unit
+}
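
Since BroadcastFactory is private[spark], any alternative implementation has to live inside the org.apache.spark.broadcast package. The following is a hypothetical minimal factory, shown only to illustrate the trait's contract; LocalOnlyBroadcast and LocalOnlyBroadcastFactory are invented names and not part of this patch.

package org.apache.spark.broadcast

// Hypothetical sketch: the smallest factory satisfying the trait. It simply wraps the
// value, so the value travels with whatever serializes the Broadcast object itself;
// there is no separate distribution mechanism.
private[spark] class LocalOnlyBroadcast[T](value_ : T, id: Long) extends Broadcast[T](id) {
  def value: T = value_
}

private[spark] class LocalOnlyBroadcastFactory extends BroadcastFactory {
  def initialize(isDriver: Boolean) { /* nothing to set up */ }
  def newBroadcast[T](value: T, isLocal: Boolean, id: Long) = new LocalOnlyBroadcast[T](value, id)
  def stop() { /* nothing to tear down */ }
}

Selecting such a factory would presumably be a matter of setting spark.broadcast.factory to its class name, mirroring how HttpBroadcastFactory is chosen by default.
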
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
new file mode 100644
index 0000000000..7a52ff0769
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
+import java.net.URL
+
+import it.unimi.dsi.fastutil.io.FastBufferedInputStream
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.{HttpServer, Logging, SparkEnv, Utils}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.{MetadataCleaner, TimeStampedHashSet}
+
+
+private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+ extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def blockId: String = "broadcast_" + id
+
+ HttpBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ if (!isLocal) {
+ HttpBroadcast.write(id, value_)
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ HttpBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(blockId) match {
+ case Some(x) => value_ = x.asInstanceOf[T]
+ case None => {
+ logInfo("Started reading broadcast variable " + id)
+ val start = System.nanoTime
+ value_ = HttpBroadcast.read[T](id)
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+ }
+}
+
+private[spark] class HttpBroadcastFactory extends BroadcastFactory {
+ def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new HttpBroadcast[T](value_, isLocal, id)
+
+ def stop() { HttpBroadcast.stop() }
+}
+
+private object HttpBroadcast extends Logging {
+ private var initialized = false
+
+ private var broadcastDir: File = null
+ private var compress: Boolean = false
+ private var bufferSize: Int = 65536
+ private var serverUri: String = null
+ private var server: HttpServer = null
+
+ private val files = new TimeStampedHashSet[String]
+ private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
+
+ private lazy val compressionCodec = CompressionCodec.createCodec()
+
+ def initialize(isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ if (isDriver) {
+ createServer()
+ }
+ serverUri = System.getProperty("spark.httpBroadcast.uri")
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ synchronized {
+ if (server != null) {
+ server.stop()
+ server = null
+ }
+ initialized = false
+ cleaner.cancel()
+ }
+ }
+
+ private def createServer() {
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
+ server = new HttpServer(broadcastDir)
+ server.start()
+ serverUri = server.uri
+ System.setProperty("spark.httpBroadcast.uri", serverUri)
+ logInfo("Broadcast server started at " + serverUri)
+ }
+
+ def write(id: Long, value: Any) {
+ val file = new File(broadcastDir, "broadcast-" + id)
+ val out: OutputStream = {
+ if (compress) {
+ compressionCodec.compressedOutputStream(new FileOutputStream(file))
+ } else {
+ new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
+ }
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serOut = ser.serializeStream(out)
+ serOut.writeObject(value)
+ serOut.close()
+ files += file.getAbsolutePath
+ }
+
+ def read[T](id: Long): T = {
+ val url = serverUri + "/broadcast-" + id
+ val in = {
+ if (compress) {
+ compressionCodec.compressedInputStream(new URL(url).openStream())
+ } else {
+ new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+ }
+ }
+ val ser = SparkEnv.get.serializer.newInstance()
+ val serIn = ser.deserializeStream(in)
+ val obj = serIn.readObject[T]()
+ serIn.close()
+ obj
+ }
+
+ def cleanup(cleanupTime: Long) {
+ val iterator = files.internalMap.entrySet().iterator()
+    while (iterator.hasNext) {
+ val entry = iterator.next()
+ val (file, time) = (entry.getKey, entry.getValue)
+ if (time < cleanupTime) {
+ try {
+ iterator.remove()
+ new File(file.toString).delete()
+ logInfo("Deleted broadcast file '" + file + "'")
+ } catch {
+ case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+ }
+ }
+ }
+ }
+}
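
The core of HttpBroadcast above is a write/read round trip: the driver serializes the value into a file under a served directory, and executors open a URL and deserialize it. A standalone sketch of that round trip follows, using plain JDK serialization and a file:// URL in place of SparkEnv's serializer and the HttpServer; HttpStyleRoundTripSketch is an illustrative name.

import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream}

object HttpStyleRoundTripSketch {
  // Driver side: serialize the value into "broadcast-<id>" under dir.
  def write(dir: File, id: Long, value: AnyRef): File = {
    val file = new File(dir, "broadcast-" + id)
    val out = new ObjectOutputStream(new FileOutputStream(file))
    out.writeObject(value)
    out.close()
    file
  }

  // Executor side: executors would open an http:// URL served by the driver;
  // a file:// URL keeps the sketch self-contained.
  def read[T](file: File): T = {
    val in = new ObjectInputStream(file.toURI.toURL.openStream())
    val obj = in.readObject.asInstanceOf[T]
    in.close()
    obj
  }

  def main(args: Array[String]) {
    val dir = new File(System.getProperty("java.io.tmpdir"))
    val f = write(dir, 0L, Array(1, 2, 3))
    println(read[Array[Int]](f).mkString(","))  // 1,2,3
  }
}
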
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
new file mode 100644
index 0000000000..10b910df87
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.Random
+
+import scala.collection.mutable.Map
+
+import org.apache.spark._
+
+private object MultiTracker
+extends Logging {
+
+ // Tracker Messages
+ val REGISTER_BROADCAST_TRACKER = 0
+ val UNREGISTER_BROADCAST_TRACKER = 1
+ val FIND_BROADCAST_TRACKER = 2
+
+ // Map to keep track of guides of ongoing broadcasts
+ var valueToGuideMap = Map[Long, SourceInfo]()
+
+ // Random number generator
+ var ranGen = new Random
+
+ private var initialized = false
+ private var _isDriver = false
+
+ private var stopBroadcast = false
+
+ private var trackMV: TrackMultipleValues = null
+
+ def initialize(__isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ _isDriver = __isDriver
+
+ if (isDriver) {
+ trackMV = new TrackMultipleValues
+ trackMV.setDaemon(true)
+ trackMV.start()
+
+ // Set DriverHostAddress to the driver's IP address for the slaves to read
+ System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
+ }
+
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ stopBroadcast = true
+ }
+
+ // Load common parameters
+ private var DriverHostAddress_ = System.getProperty(
+ "spark.MultiTracker.DriverHostAddress", "")
+ private var DriverTrackerPort_ = System.getProperty(
+ "spark.broadcast.driverTrackerPort", "11111").toInt
+ private var BlockSize_ = System.getProperty(
+ "spark.broadcast.blockSize", "4096").toInt * 1024
+ private var MaxRetryCount_ = System.getProperty(
+ "spark.broadcast.maxRetryCount", "2").toInt
+
+ private var TrackerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.trackerSocketTimeout", "50000").toInt
+ private var ServerSocketTimeout_ = System.getProperty(
+ "spark.broadcast.serverSocketTimeout", "10000").toInt
+
+ private var MinKnockInterval_ = System.getProperty(
+ "spark.broadcast.minKnockInterval", "500").toInt
+ private var MaxKnockInterval_ = System.getProperty(
+ "spark.broadcast.maxKnockInterval", "999").toInt
+
+ // Load TreeBroadcast config params
+ private var MaxDegree_ = System.getProperty(
+ "spark.broadcast.maxDegree", "2").toInt
+
+ // Load BitTorrentBroadcast config params
+ private var MaxPeersInGuideResponse_ = System.getProperty(
+ "spark.broadcast.maxPeersInGuideResponse", "4").toInt
+
+ private var MaxChatSlots_ = System.getProperty(
+ "spark.broadcast.maxChatSlots", "4").toInt
+ private var MaxChatTime_ = System.getProperty(
+ "spark.broadcast.maxChatTime", "500").toInt
+ private var MaxChatBlocks_ = System.getProperty(
+ "spark.broadcast.maxChatBlocks", "1024").toInt
+
+ private var EndGameFraction_ = System.getProperty(
+ "spark.broadcast.endGameFraction", "0.95").toDouble
+
+ def isDriver = _isDriver
+
+ // Common config params
+ def DriverHostAddress = DriverHostAddress_
+ def DriverTrackerPort = DriverTrackerPort_
+ def BlockSize = BlockSize_
+ def MaxRetryCount = MaxRetryCount_
+
+ def TrackerSocketTimeout = TrackerSocketTimeout_
+ def ServerSocketTimeout = ServerSocketTimeout_
+
+ def MinKnockInterval = MinKnockInterval_
+ def MaxKnockInterval = MaxKnockInterval_
+
+ // TreeBroadcast configs
+ def MaxDegree = MaxDegree_
+
+ // BitTorrentBroadcast configs
+ def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
+
+ def MaxChatSlots = MaxChatSlots_
+ def MaxChatTime = MaxChatTime_
+ def MaxChatBlocks = MaxChatBlocks_
+
+ def EndGameFraction = EndGameFraction_
+
+ class TrackMultipleValues
+ extends Thread with Logging {
+ override def run() {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(DriverTrackerPort)
+ logInfo("TrackMultipleValues started at " + serverSocket)
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(TrackerSocketTimeout)
+ clientSocket = serverSocket.accept()
+ } catch {
+ case e: Exception => {
+ if (stopBroadcast) {
+ logInfo("Stopping TrackMultipleValues...")
+ }
+ }
+ }
+
+ if (clientSocket != null) {
+ try {
+ threadPool.execute(new Thread {
+ override def run() {
+ val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ try {
+ // First, read message type
+ val messageType = ois.readObject.asInstanceOf[Int]
+
+ if (messageType == REGISTER_BROADCAST_TRACKER) {
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
+ // Receive hostAddress and listenPort
+ val gInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ // Add to the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap += (id -> gInfo)
+ }
+
+ logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
+
+ // Remove from the map
+ valueToGuideMap.synchronized {
+ valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
+ }
+
+ logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
+
+ // Send dummy ACK
+ oos.writeObject(-1)
+ oos.flush()
+ } else if (messageType == FIND_BROADCAST_TRACKER) {
+ // Receive Long
+ val id = ois.readObject.asInstanceOf[Long]
+
+ var gInfo =
+ if (valueToGuideMap.contains(id)) valueToGuideMap(id)
+ else SourceInfo("", SourceInfo.TxNotStartedRetry)
+
+ logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
+
+ // Send reply back
+ oos.writeObject(gInfo)
+ oos.flush()
+ } else {
+ throw new SparkException("Undefined messageType at TrackMultipleValues")
+ }
+ } catch {
+ case e: Exception => {
+ logError("TrackMultipleValues had a " + e)
+ }
+ } finally {
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+ })
+ } catch {
+ // In failure, close socket here; else, client thread will close
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ serverSocket.close()
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+ }
+
+ def getGuideInfo(variableLong: Long): SourceInfo = {
+ var clientSocketToTracker: Socket = null
+ var oosTracker: ObjectOutputStream = null
+ var oisTracker: ObjectInputStream = null
+
+ var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
+
+ var retriesLeft = MultiTracker.MaxRetryCount
+ do {
+ try {
+ // Connect to the tracker to find out GuideInfo
+ clientSocketToTracker =
+ new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
+ oosTracker =
+ new ObjectOutputStream(clientSocketToTracker.getOutputStream)
+ oosTracker.flush()
+ oisTracker =
+ new ObjectInputStream(clientSocketToTracker.getInputStream)
+
+ // Send messageType/intention
+ oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
+ oosTracker.flush()
+
+ // Send Long and receive GuideInfo
+ oosTracker.writeObject(variableLong)
+ oosTracker.flush()
+ gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
+ } catch {
+ case e: Exception => logError("getGuideInfo had a " + e)
+ } finally {
+ if (oisTracker != null) {
+ oisTracker.close()
+ }
+ if (oosTracker != null) {
+ oosTracker.close()
+ }
+ if (clientSocketToTracker != null) {
+ clientSocketToTracker.close()
+ }
+ }
+
+ Thread.sleep(MultiTracker.ranGen.nextInt(
+ MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
+ MultiTracker.MinKnockInterval)
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
+
+ logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
+ return gInfo
+ }
+
+ def registerBroadcast(id: Long, gInfo: SourceInfo) {
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(REGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send Long of this broadcast
+ oosST.writeObject(id)
+ oosST.flush()
+
+ // Send this tracker's information
+ oosST.writeObject(gInfo)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ def unregisterBroadcast(id: Long) {
+ val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
+ val oosST = new ObjectOutputStream(socket.getOutputStream)
+ oosST.flush()
+ val oisST = new ObjectInputStream(socket.getInputStream)
+
+ // Send messageType/intention
+ oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
+ oosST.flush()
+
+ // Send Long of this broadcast
+ oosST.writeObject(id)
+ oosST.flush()
+
+ // Receive ACK and throw it away
+ oisST.readObject.asInstanceOf[Int]
+
+ // Shut stuff down
+ oisST.close()
+ oosST.close()
+ socket.close()
+ }
+
+ // Helper method to convert an object to Array[BroadcastBlock]
+ def blockifyObject[IN](obj: IN): VariableInfo = {
+ val baos = new ByteArrayOutputStream
+ val oos = new ObjectOutputStream(baos)
+ oos.writeObject(obj)
+ oos.close()
+ baos.close()
+ val byteArray = baos.toByteArray
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BlockSize)
+ if (byteArray.length % BlockSize != 0)
+ blockNum += 1
+
+ var retVal = new Array[BroadcastBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BlockSize)) {
+ val thisBlockSize = math.min(BlockSize, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
+ variableInfo.hasBlocks = blockNum
+
+ return variableInfo
+ }
+
+ // Helper method to convert Array[BroadcastBlock] to object
+ def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
+ totalBytes: Int,
+ totalBlocks: Int): OUT = {
+
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BlockSize, arrayOfBlocks(i).byteArray.length)
+ }
+ byteArrayToObject(retByteArray)
+ }
+
+ private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
+ val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
+ override def resolveClass(desc: ObjectStreamClass) =
+ Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
+ }
+ val retVal = in.readObject.asInstanceOf[OUT]
+ in.close()
+ return retVal
+ }
+}
+
+private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
+extends Serializable
+
+private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+extends Serializable {
+ @transient var hasBlocks = 0
+}
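
blockifyObject and unBlockifyObject above amount to splitting a serialized byte array into BlockSize chunks and copying them back into place. A standalone sketch of that arithmetic, assuming nothing from Spark; BlockifySketch and blockSize are illustrative.

object BlockifySketch {
  // Split bytes into ceil(length / blockSize) chunks; the last chunk may be short.
  def blockify(bytes: Array[Byte], blockSize: Int): Array[Array[Byte]] = {
    val numBlocks = (bytes.length + blockSize - 1) / blockSize
    Array.tabulate(numBlocks) { i =>
      val from = i * blockSize
      java.util.Arrays.copyOfRange(bytes, from, math.min(from + blockSize, bytes.length))
    }
  }

  // Reassemble: each block i lands at offset i * blockSize, as in unBlockifyObject.
  def unBlockify(blocks: Array[Array[Byte]], totalBytes: Int, blockSize: Int): Array[Byte] = {
    val out = new Array[Byte](totalBytes)
    for (i <- 0 until blocks.length) {
      System.arraycopy(blocks(i), 0, out, i * blockSize, blocks(i).length)
    }
    out
  }

  def main(args: Array[String]) {
    val data = Array.tabulate[Byte](10)(_.toByte)
    val blocks = blockify(data, 4)               // 3 blocks: 4 + 4 + 2 bytes
    val back = unBlockify(blocks, data.length, 4)
    println(data.sameElements(back))             // true
  }
}
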
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
new file mode 100644
index 0000000000..baa1fd6da4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.util.BitSet
+
+import org.apache.spark._
+
+/**
+ * Used to keep and pass around information of peers involved in a broadcast
+ */
+private[spark] case class SourceInfo (hostAddress: String,
+ listenPort: Int,
+ totalBlocks: Int = SourceInfo.UnusedParam,
+ totalBytes: Int = SourceInfo.UnusedParam)
+extends Comparable[SourceInfo] with Logging {
+
+ var currentLeechers = 0
+ var receptionFailed = false
+
+ var hasBlocks = 0
+ var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
+
+ // Ascending sort based on leecher count
+ def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
+}
+
+/**
+ * Helper Object of SourceInfo for its constants
+ */
+private[spark] object SourceInfo {
+ // Broadcast has not started yet! Should never happen.
+ val TxNotStartedRetry = -1
+ // Broadcast has already finished. Try default mechanism.
+ val TxOverGoToDefault = -3
+ // Other constants
+ val StopBroadcast = -2
+ val UnusedParam = 0
+}
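
Because SourceInfo implements Comparable on currentLeechers, sorting a list of sources puts the least-loaded peer first. A small in-package sketch of that ordering; the hosts, ports and leecher counts are made up for illustration.

package org.apache.spark.broadcast

object SourceOrderingSketch {
  def main(args: Array[String]) {
    val a = SourceInfo("10.0.0.1", 5001); a.currentLeechers = 3
    val b = SourceInfo("10.0.0.2", 5002); b.currentLeechers = 1
    // Ascending sort by leecher count via the compareTo defined above.
    val sorted = List(a, b).sortWith(_.compareTo(_) < 0)
    println(sorted.map(_.hostAddress))  // List(10.0.0.2, 10.0.0.1)
  }
}
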
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
new file mode 100644
index 0000000000..b5a4ccc0ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+import java.net._
+import java.util.{Comparator, Random, UUID}
+
+import scala.collection.mutable.{ListBuffer, Map, Set}
+import scala.math
+
+import org.apache.spark._
+import org.apache.spark.storage.StorageLevel
+
+private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def blockId = "broadcast_" + id
+
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[BroadcastBlock] = null
+ @transient var totalBytes = -1
+ @transient var totalBlocks = -1
+ @transient var hasBlocks = 0
+
+ @transient var listenPortLock = new Object
+ @transient var guidePortLock = new Object
+ @transient var totalBlocksLock = new Object
+ @transient var hasBlocksLock = new Object
+
+ @transient var listOfSources = ListBuffer[SourceInfo]()
+
+ @transient var serveMR: ServeMultipleRequests = null
+ @transient var guideMR: GuideMultipleRequests = null
+
+ @transient var hostAddress = Utils.localIpAddress
+ @transient var listenPort = -1
+ @transient var guidePort = -1
+
+ @transient var stopBroadcast = false
+
+ // Must call this after all the variables have been created/initialized
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ logInfo("Local host address: " + hostAddress)
+
+ // Create a variableInfo object and store it in valueInfos
+ var variableInfo = MultiTracker.blockifyObject(value_)
+
+ // Prepare the value being broadcasted
+ arrayOfBlocks = variableInfo.arrayOfBlocks
+ totalBytes = variableInfo.totalBytes
+ totalBlocks = variableInfo.totalBlocks
+ hasBlocks = variableInfo.totalBlocks
+
+ guideMR = new GuideMultipleRequests
+ guideMR.setDaemon(true)
+ guideMR.start()
+ logInfo("GuideMultipleRequests started...")
+
+ // Must always come AFTER guideMR is created
+ while (guidePort == -1) {
+ guidePortLock.synchronized { guidePortLock.wait() }
+ }
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ // Must always come AFTER serveMR is created
+ while (listenPort == -1) {
+ listenPortLock.synchronized { listenPortLock.wait() }
+ }
+
+ // Must always come AFTER listenPort is created
+ val masterSource =
+ SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
+ listOfSources += masterSource
+
+ // Register with the Tracker
+ MultiTracker.registerBroadcast(id,
+ SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
+ }
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ MultiTracker.synchronized {
+ SparkEnv.get.blockManager.getSingle(blockId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ logInfo("Started reading broadcast variable " + id)
+ // Initializing everything because Driver will only send null/0 values
+ // Only the 1st worker in a node can be here. Others will get from cache
+ initializeWorkerVariables()
+
+ logInfo("Local host address: " + hostAddress)
+
+ serveMR = new ServeMultipleRequests
+ serveMR.setDaemon(true)
+ serveMR.start()
+ logInfo("ServeMultipleRequests started...")
+
+ val start = System.nanoTime
+
+ val receptionSucceeded = receiveBroadcast(id)
+ if (receptionSucceeded) {
+ value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def initializeWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+
+ listenPortLock = new Object
+ totalBlocksLock = new Object
+ hasBlocksLock = new Object
+
+ serveMR = null
+
+ hostAddress = Utils.localIpAddress
+ listenPort = -1
+
+ stopBroadcast = false
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ val gInfo = MultiTracker.getGuideInfo(variableID)
+
+ if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
+ return false
+ }
+
+ // Wait until hostAddress and listenPort are created by the
+ // ServeMultipleRequests thread
+ while (listenPort == -1) {
+ listenPortLock.synchronized { listenPortLock.wait() }
+ }
+
+ var clientSocketToDriver: Socket = null
+ var oosDriver: ObjectOutputStream = null
+ var oisDriver: ObjectInputStream = null
+
+ // Connect and receive broadcast from the specified source, retrying the
+ // specified number of times in case of failures
+ var retriesLeft = MultiTracker.MaxRetryCount
+ do {
+ // Connect to Driver and send this worker's Information
+ clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
+ oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
+ oosDriver.flush()
+ oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
+
+ logDebug("Connected to Driver's guiding object")
+
+ // Send local source information
+ oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
+ oosDriver.flush()
+
+ // Receive source information from Driver
+ var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
+ totalBlocks = sourceInfo.totalBlocks
+ arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
+ totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
+ totalBytes = sourceInfo.totalBytes
+
+ logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
+
+ val start = System.nanoTime
+ val receptionSucceeded = receiveSingleTransmission(sourceInfo)
+ val time = (System.nanoTime - start) / 1e9
+
+ // Updating some statistics in sourceInfo. Driver will be using them later
+ if (!receptionSucceeded) {
+ sourceInfo.receptionFailed = true
+ }
+
+ // Send back statistics to the Driver
+ oosDriver.writeObject(sourceInfo)
+
+ if (oisDriver != null) {
+ oisDriver.close()
+ }
+ if (oosDriver != null) {
+ oosDriver.close()
+ }
+ if (clientSocketToDriver != null) {
+ clientSocketToDriver.close()
+ }
+
+ retriesLeft -= 1
+ } while (retriesLeft > 0 && hasBlocks < totalBlocks)
+
+ return (hasBlocks == totalBlocks)
+ }
+
+ /**
+ * Tries to receive broadcast from the source and returns Boolean status.
+ * This might be called multiple times to retry a defined number of times.
+ */
+ private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
+ var clientSocketToSource: Socket = null
+ var oosSource: ObjectOutputStream = null
+ var oisSource: ObjectInputStream = null
+
+ var receptionSucceeded = false
+ try {
+ // Connect to the source to get the object itself
+ clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
+ oosSource.flush()
+ oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
+
+ logDebug("Inside receiveSingleTransmission")
+ logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
+
+ // Send the range
+ oosSource.writeObject((hasBlocks, totalBlocks))
+ oosSource.flush()
+
+ for (i <- hasBlocks until totalBlocks) {
+ val recvStartTime = System.currentTimeMillis
+ val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
+ val receptionTime = (System.currentTimeMillis - recvStartTime)
+
+ logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
+
+ arrayOfBlocks(hasBlocks) = bcBlock
+ hasBlocks += 1
+
+ // Set to true if at least one block is received
+ receptionSucceeded = true
+ hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
+ }
+ } catch {
+ case e: Exception => logError("receiveSingleTransmission had a " + e)
+ } finally {
+ if (oisSource != null) {
+ oisSource.close()
+ }
+ if (oosSource != null) {
+ oosSource.close()
+ }
+ if (clientSocketToSource != null) {
+ clientSocketToSource.close()
+ }
+ }
+
+ return receptionSucceeded
+ }
+
+ class GuideMultipleRequests
+ extends Thread with Logging {
+ // Keep track of sources that have completed reception
+ private var setOfCompletedSources = Set[SourceInfo]()
+
+ override def run() {
+ var threadPool = Utils.newDaemonCachedThreadPool()
+ var serverSocket: ServerSocket = null
+
+ serverSocket = new ServerSocket(0)
+ guidePort = serverSocket.getLocalPort
+ logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
+
+ guidePortLock.synchronized { guidePortLock.notifyAll() }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => {
+ // Stop broadcast if at least one worker has connected and
+              // everyone connected so far is done. Comparing with
+ // listOfSources.size - 1, because it includes the Guide itself
+ listOfSources.synchronized {
+ setOfCompletedSources.synchronized {
+ if (listOfSources.size > 1 &&
+ setOfCompletedSources.size == listOfSources.size - 1) {
+ stopBroadcast = true
+ logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
+ }
+ }
+ }
+ }
+ }
+ if (clientSocket != null) {
+ logDebug("Guide: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new GuideSingleRequest(clientSocket))
+ } catch {
+ // In failure, close() the socket here; else, the thread will close() it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+
+ logInfo("Sending stopBroadcast notifications...")
+ sendStopBroadcastNotifications
+
+ MultiTracker.unregisterBroadcast(id)
+ } finally {
+ if (serverSocket != null) {
+ logInfo("GuideMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ private def sendStopBroadcastNotifications() {
+ listOfSources.synchronized {
+ var listIter = listOfSources.iterator
+ while (listIter.hasNext) {
+ var sourceInfo = listIter.next
+
+ var guideSocketToSource: Socket = null
+ var gosSource: ObjectOutputStream = null
+ var gisSource: ObjectInputStream = null
+
+ try {
+ // Connect to the source
+ guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
+ gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
+ gosSource.flush()
+ gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
+
+ // Send stopBroadcast signal
+ gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
+ gosSource.flush()
+ } catch {
+ case e: Exception => {
+ logError("sendStopBroadcastNotifications had a " + e)
+ }
+ } finally {
+ if (gisSource != null) {
+ gisSource.close()
+ }
+ if (gosSource != null) {
+ gosSource.close()
+ }
+ if (guideSocketToSource != null) {
+ guideSocketToSource.close()
+ }
+ }
+ }
+ }
+ }
+
+ class GuideSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var selectedSourceInfo: SourceInfo = null
+      private var thisWorkerInfo: SourceInfo = null
+
+ override def run() {
+ try {
+ logInfo("new GuideSingleRequest is running")
+          // The connecting worker sends the hostAddress and listenPort it will
+          // be listening on. Other fields are invalid (SourceInfo.UnusedParam)
+ var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ listOfSources.synchronized {
+ // Select a suitable source and send it back to the worker
+ selectedSourceInfo = selectSuitableSource(sourceInfo)
+ logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
+ oos.writeObject(selectedSourceInfo)
+ oos.flush()
+
+ // Add this new (if it can finish) source to the list of sources
+ thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
+ sourceInfo.listenPort, totalBlocks, totalBytes)
+ logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
+ listOfSources += thisWorkerInfo
+ }
+
+ // Wait till the whole transfer is done. Then receive and update source
+ // statistics in listOfSources
+ sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
+
+ listOfSources.synchronized {
+ // This should work since SourceInfo is a case class
+ assert(listOfSources.contains(selectedSourceInfo))
+
+ // Remove first
+ // (Currently removing a source based on just one failure notification!)
+ listOfSources = listOfSources - selectedSourceInfo
+
+ // Update sourceInfo and put it back in, IF reception succeeded
+ if (!sourceInfo.receptionFailed) {
+ // Add thisWorkerInfo to sources that have completed reception
+ setOfCompletedSources.synchronized {
+ setOfCompletedSources += thisWorkerInfo
+ }
+
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ listOfSources += selectedSourceInfo
+ }
+ }
+ } catch {
+ case e: Exception => {
+ // Remove failed worker from listOfSources and update leecherCount of
+ // corresponding source worker
+ listOfSources.synchronized {
+ if (selectedSourceInfo != null) {
+ // Remove first
+ listOfSources = listOfSources - selectedSourceInfo
+ // Update leecher count and put it back in
+ selectedSourceInfo.currentLeechers -= 1
+ listOfSources += selectedSourceInfo
+ }
+
+ // Remove thisWorkerInfo
+ if (listOfSources != null) {
+ listOfSources = listOfSources - thisWorkerInfo
+ }
+ }
+ }
+ } finally {
+ logInfo("GuideSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+      // Assumes the caller holds a synchronized block on listOfSources
+ // Select one with the most leechers. This will level-wise fill the tree
+ private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
+ var maxLeechers = -1
+ var selectedSource: SourceInfo = null
+
+ listOfSources.foreach { source =>
+ if ((source.hostAddress != skipSourceInfo.hostAddress ||
+ source.listenPort != skipSourceInfo.listenPort) &&
+ source.currentLeechers < MultiTracker.MaxDegree &&
+ source.currentLeechers > maxLeechers) {
+ selectedSource = source
+ maxLeechers = source.currentLeechers
+ }
+ }
+
+ // Update leecher count
+ selectedSource.currentLeechers += 1
+ return selectedSource
+ }
+ }
+ }
+
+ class ServeMultipleRequests
+ extends Thread with Logging {
+
+ var threadPool = Utils.newDaemonCachedThreadPool()
+
+ override def run() {
+ var serverSocket = new ServerSocket(0)
+ listenPort = serverSocket.getLocalPort
+
+ logInfo("ServeMultipleRequests started with " + serverSocket)
+
+ listenPortLock.synchronized { listenPortLock.notifyAll() }
+
+ try {
+ while (!stopBroadcast) {
+ var clientSocket: Socket = null
+ try {
+ serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
+ clientSocket = serverSocket.accept
+ } catch {
+ case e: Exception => { }
+ }
+
+ if (clientSocket != null) {
+ logDebug("Serve: Accepted new client connection: " + clientSocket)
+ try {
+ threadPool.execute(new ServeSingleRequest(clientSocket))
+ } catch {
+ // In failure, close socket here; else, the thread will close it
+ case ioe: IOException => clientSocket.close()
+ }
+ }
+ }
+ } finally {
+ if (serverSocket != null) {
+ logInfo("ServeMultipleRequests now stopping...")
+ serverSocket.close()
+ }
+ }
+ // Shutdown the thread pool
+ threadPool.shutdown()
+ }
+
+ class ServeSingleRequest(val clientSocket: Socket)
+ extends Thread with Logging {
+ private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
+ oos.flush()
+ private val ois = new ObjectInputStream(clientSocket.getInputStream)
+
+ private var sendFrom = 0
+ private var sendUntil = totalBlocks
+
+ override def run() {
+ try {
+ logInfo("new ServeSingleRequest is running")
+
+ // Receive range to send
+ var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
+ sendFrom = rangeToSend._1
+ sendUntil = rangeToSend._2
+
+ // If not a valid range, stop broadcast
+ if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
+ stopBroadcast = true
+ } else {
+ sendObject
+ }
+ } catch {
+ case e: Exception => logError("ServeSingleRequest had a " + e)
+ } finally {
+ logInfo("ServeSingleRequest is closing streams and sockets")
+ ois.close()
+ oos.close()
+ clientSocket.close()
+ }
+ }
+
+ private def sendObject() {
+ // Wait till receiving the SourceInfo from Driver
+ while (totalBlocks == -1) {
+ totalBlocksLock.synchronized { totalBlocksLock.wait() }
+ }
+
+ for (i <- sendFrom until sendUntil) {
+ while (i == hasBlocks) {
+ hasBlocksLock.synchronized { hasBlocksLock.wait() }
+ }
+ try {
+ oos.writeObject(arrayOfBlocks(i))
+ oos.flush()
+ } catch {
+ case e: Exception => logError("sendObject had a " + e)
+ }
+ logDebug("Sent block: " + i + " to " + clientSocket)
+ }
+ }
+ }
+ }
+}
+
+private[spark] class TreeBroadcastFactory
+extends BroadcastFactory {
+ def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TreeBroadcast[T](value_, isLocal, id)
+
+ def stop() { MultiTracker.stop() }
+}
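
selectSuitableSource in GuideSingleRequest above prefers the candidate with the most leechers that is still below MaxDegree, which fills the broadcast tree level by level. A standalone sketch of that choice, returning an Option instead of the null used above; Peer, select and maxDegree are illustrative names.

object TreeFillSketch {
  case class Peer(host: String, var leechers: Int)

  // Among peers that are not the requester and are below the degree cap,
  // pick the one with the most leechers so the tree fills level-wise.
  def select(peers: Seq[Peer], requester: String, maxDegree: Int): Option[Peer] = {
    val candidates = peers.filter(p => p.host != requester && p.leechers < maxDegree)
    if (candidates.isEmpty) None
    else Some(candidates.maxBy(_.leechers))
  }

  def main(args: Array[String]) {
    val peers = Seq(Peer("driver", 1), Peer("w1", 2), Peer("w2", 0))
    // w1 is already at the cap, so the driver (1 leecher) beats w2 (0 leechers).
    println(select(peers, "w3", maxDegree = 2))  // Some(Peer(driver,1))
  }
}
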
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
new file mode 100644
index 0000000000..19d393a0db
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+private[spark] class ApplicationDescription(
+ val name: String,
+ val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
+ val memoryPerSlave: Int,
+ val command: Command,
+ val sparkHome: String,
+ val appUiUrl: String)
+ extends Serializable {
+
+ val user = System.getProperty("user.name", "<unknown>")
+
+ override def toString: String = "ApplicationDescription(" + name + ")"
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/Command.scala b/core/src/main/scala/org/apache/spark/deploy/Command.scala
new file mode 100644
index 0000000000..fa8af9a646
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/Command.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import scala.collection.Map
+
+private[spark] case class Command(
+ mainClass: String,
+ arguments: Seq[String],
+ environment: Map[String, String]) {
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
new file mode 100644
index 0000000000..4dc6ada2d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import scala.collection.immutable.List
+
+import org.apache.spark.Utils
+import org.apache.spark.deploy.ExecutorState.ExecutorState
+import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
+import org.apache.spark.deploy.worker.ExecutorRunner
+
+
+private[deploy] sealed trait DeployMessage extends Serializable
+
+private[deploy] object DeployMessages {
+
+ // Worker to Master
+
+ case class RegisterWorker(
+ id: String,
+ host: String,
+ port: Int,
+ cores: Int,
+ memory: Int,
+ webUiPort: Int,
+ publicAddress: String)
+ extends DeployMessage {
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+ }
+
+ case class ExecutorStateChanged(
+ appId: String,
+ execId: Int,
+ state: ExecutorState,
+ message: Option[String],
+ exitStatus: Option[Int])
+ extends DeployMessage
+
+ case class Heartbeat(workerId: String) extends DeployMessage
+
+ // Master to Worker
+
+ case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
+
+ case class RegisterWorkerFailed(message: String) extends DeployMessage
+
+ case class KillExecutor(appId: String, execId: Int) extends DeployMessage
+
+ case class LaunchExecutor(
+ appId: String,
+ execId: Int,
+ appDesc: ApplicationDescription,
+ cores: Int,
+ memory: Int,
+ sparkHome: String)
+ extends DeployMessage
+
+ // Client to Master
+
+ case class RegisterApplication(appDescription: ApplicationDescription)
+ extends DeployMessage
+
+ // Master to Client
+
+ case class RegisteredApplication(appId: String) extends DeployMessage
+
+ // TODO(matei): replace hostPort with host
+ case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ Utils.checkHostPort(hostPort, "Required hostport")
+ }
+
+ case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
+ exitStatus: Option[Int])
+
+ case class ApplicationRemoved(message: String)
+
+ // Internal message in Client
+
+ case object StopClient
+
+ // MasterWebUI To Master
+
+ case object RequestMasterState
+
+ // Master to MasterWebUI
+
+ case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
+ activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
+
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+
+ def uri = "spark://" + host + ":" + port
+ }
+
+ // WorkerWebUI to Worker
+
+ case object RequestWorkerState
+
+ // Worker to WorkerWebUI
+
+ case class WorkerStateResponse(host: String, port: Int, workerId: String,
+ executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String,
+ cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
+
+ Utils.checkHost(host, "Required hostname")
+ assert (port > 0)
+ }
+
+ // Actor System to Master
+
+ case object CheckForWorkerTimeOut
+
+}
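To make the request/response pairs above concrete, here is a small usage sketch that asks the Master actor for its state, mirroring the ask pattern the web UI pages later in this patch use. `masterActorRef` is a hypothetical ActorRef for the Master, and the snippet assumes code living inside org.apache.spark.deploy, since DeployMessages is private[deploy].

// Usage sketch (not part of the patch): ask the Master for its current state.
import akka.dispatch.Await
import akka.pattern.ask
import akka.util.duration._

import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}

val timeout = 10 seconds
val stateFuture = (masterActorRef ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
val state = Await.result(stateFuture, timeout)
state.workers.foreach(w => println(w.id + ": " + w.coresFree + " cores free"))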
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
new file mode 100644
index 0000000000..fcfea96ad6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+private[spark] object ExecutorState
+ extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
+
+ val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
+
+ type ExecutorState = Value
+
+ def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state)
+}
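For illustration, a tiny sketch of how the terminal-state helper above is typically consumed. The state value is made up, and the snippet assumes code inside the org.apache.spark package since the object is private[spark].

import org.apache.spark.deploy.ExecutorState

val state = ExecutorState.KILLED
// KILLED, FAILED and LOST are terminal, so isFinished returns true and cleanup can proceed.
if (ExecutorState.isFinished(state)) {
  println("Executor reached terminal state: " + state)
}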
diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
new file mode 100644
index 0000000000..a6be8efef1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import net.liftweb.json.JsonDSL._
+
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
+import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
+import org.apache.spark.deploy.worker.ExecutorRunner
+
+
+private[spark] object JsonProtocol {
+ def writeWorkerInfo(obj: WorkerInfo) = {
+ ("id" -> obj.id) ~
+ ("host" -> obj.host) ~
+ ("port" -> obj.port) ~
+ ("webuiaddress" -> obj.webUiAddress) ~
+ ("cores" -> obj.cores) ~
+ ("coresused" -> obj.coresUsed) ~
+ ("memory" -> obj.memory) ~
+ ("memoryused" -> obj.memoryUsed) ~
+ ("state" -> obj.state.toString)
+ }
+
+ def writeApplicationInfo(obj: ApplicationInfo) = {
+ ("starttime" -> obj.startTime) ~
+ ("id" -> obj.id) ~
+ ("name" -> obj.desc.name) ~
+ ("cores" -> obj.desc.maxCores) ~
+ ("user" -> obj.desc.user) ~
+ ("memoryperslave" -> obj.desc.memoryPerSlave) ~
+ ("submitdate" -> obj.submitDate.toString)
+ }
+
+ def writeApplicationDescription(obj: ApplicationDescription) = {
+ ("name" -> obj.name) ~
+ ("cores" -> obj.maxCores) ~
+ ("memoryperslave" -> obj.memoryPerSlave) ~
+ ("user" -> obj.user)
+ }
+
+ def writeExecutorRunner(obj: ExecutorRunner) = {
+ ("id" -> obj.execId) ~
+ ("memory" -> obj.memory) ~
+ ("appid" -> obj.appId) ~
+ ("appdesc" -> writeApplicationDescription(obj.appDesc))
+ }
+
+ def writeMasterState(obj: MasterStateResponse) = {
+ ("url" -> obj.uri) ~
+ ("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~
+ ("cores" -> obj.workers.map(_.cores).sum) ~
+ ("coresused" -> obj.workers.map(_.coresUsed).sum) ~
+ ("memory" -> obj.workers.map(_.memory).sum) ~
+ ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
+ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
+ ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
+ }
+
+ def writeWorkerState(obj: WorkerStateResponse) = {
+ ("id" -> obj.workerId) ~
+ ("masterurl" -> obj.masterUrl) ~
+ ("masterwebuiurl" -> obj.masterWebUiUrl) ~
+ ("cores" -> obj.cores) ~
+ ("coresused" -> obj.coresUsed) ~
+ ("memory" -> obj.memory) ~
+ ("memoryused" -> obj.memoryUsed) ~
+ ("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~
+ ("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner))
+ }
+}
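As a usage sketch, the writers above produce lift-json JValues that can be printed with lift-json's standard compact(render(...)) helpers. The application description below is made up (similar to the one TestClient builds later in this patch), and the snippet assumes code inside the org.apache.spark package since JsonProtocol is private[spark].

import net.liftweb.json._  // compact() and render()

import org.apache.spark.deploy.{ApplicationDescription, Command, JsonProtocol}

// A made-up application description for illustration only.
val desc = new ApplicationDescription(
  "ExampleApp", 4, 512, Command("org.example.Main", Seq(), Map()), "/opt/spark", "http://driver:4040")

val json: String = compact(render(JsonProtocol.writeApplicationDescription(desc)))
// e.g. {"name":"ExampleApp","cores":4,"memoryperslave":512,"user":"<current user>"}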
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
new file mode 100644
index 0000000000..af5a4110b0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+
+import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.deploy.master.Master
+import org.apache.spark.util.AkkaUtils
+import org.apache.spark.{Logging, Utils}
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Testing class that creates a Spark standalone cluster in-process (that is, running the
+ * deploy Master and Workers in the same JVM). Executors launched by the Workers still run in
+ * separate JVMs. This can be used to test distributed operation and fault recovery without
+ * spinning up a lot of processes.
+ */
+private[spark]
+class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
+
+ private val localHostname = Utils.localHostName()
+ private val masterActorSystems = ArrayBuffer[ActorSystem]()
+ private val workerActorSystems = ArrayBuffer[ActorSystem]()
+
+ def start(): String = {
+ logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
+
+ /* Start the Master */
+ val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
+ masterActorSystems += masterSystem
+ val masterUrl = "spark://" + localHostname + ":" + masterPort
+
+ /* Start the Workers */
+ for (workerNum <- 1 to numWorkers) {
+ val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
+ memoryPerWorker, masterUrl, null, Some(workerNum))
+ workerActorSystems += workerSystem
+ }
+
+ return masterUrl
+ }
+
+ def stop() {
+ logInfo("Shutting down local Spark cluster.")
+ // Stop the workers before the master so they don't get upset that it disconnected
+ workerActorSystems.foreach(_.shutdown())
+ workerActorSystems.foreach(_.awaitTermination())
+
+ masterActorSystems.foreach(_.shutdown())
+ masterActorSystems.foreach(_.awaitTermination())
+ }
+}
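A usage sketch for the class above. The sizes are arbitrary, and since LocalSparkCluster is private[spark] the snippet assumes code compiled inside the org.apache.spark package.

import org.apache.spark.deploy.LocalSparkCluster

// Spin up an in-process standalone cluster: 2 workers, 1 core and 512 MB each.
val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
val masterUrl = cluster.start()  // e.g. "spark://<local-hostname>:<port>"
try {
  // ... point a SparkContext or a test at masterUrl ...
} finally {
  cluster.stop()
}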
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
new file mode 100644
index 0000000000..0a5f4c368f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapred.JobConf
+
+
+/**
+ * Contains utility methods for interacting with Hadoop from Spark.
+ */
+class SparkHadoopUtil {
+
+ // Returns an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
+ def newConfiguration(): Configuration = new Configuration()
+
+ // Add any user credentials to the job conf that are necessary for running on a secure Hadoop cluster
+ def addCredentials(conf: JobConf) {}
+
+ def isYarnMode(): Boolean = { false }
+
+}
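A brief sketch of how these hooks are used. In the base class above, addCredentials is a no-op; YARN-specific subclasses are expected to supply real credentials. The JobConf built here is purely illustrative.

import org.apache.hadoop.mapred.JobConf
import org.apache.spark.deploy.SparkHadoopUtil

val hadoopUtil = new SparkHadoopUtil
val conf = hadoopUtil.newConfiguration()  // plain Hadoop Configuration in this base class
val jobConf = new JobConf(conf)
hadoopUtil.addCredentials(jobConf)        // no-op here
assert(!hadoopUtil.isYarnMode())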
diff --git a/core/src/main/scala/org/apache/spark/deploy/WebUI.scala b/core/src/main/scala/org/apache/spark/deploy/WebUI.scala
new file mode 100644
index 0000000000..ae258b58b9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/WebUI.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+/**
+ * Utilities used throughout the web UI.
+ */
+private[spark] object DeployWebUI {
+ val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+
+ def formatDate(date: Date): String = DATE_FORMAT.format(date)
+
+ def formatDate(timestamp: Long): String = DATE_FORMAT.format(new Date(timestamp))
+
+ def formatDuration(milliseconds: Long): String = {
+ val seconds = milliseconds.toDouble / 1000
+ if (seconds < 60) {
+ return "%.0f s".format(seconds)
+ }
+ val minutes = seconds / 60
+ if (minutes < 10) {
+ return "%.1f min".format(minutes)
+ } else if (minutes < 60) {
+ return "%.0f min".format(minutes)
+ }
+ val hours = minutes / 60
+ return "%.1f h".format(hours)
+ }
+}
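Some sample outputs of formatDuration, worked out from the branches above. DeployWebUI is private[spark], so callers sit inside the org.apache.spark package.

import org.apache.spark.deploy.DeployWebUI

DeployWebUI.formatDuration(45 * 1000L)       // "45 s"
DeployWebUI.formatDuration(125 * 1000L)      // "2.1 min"
DeployWebUI.formatDuration(45 * 60 * 1000L)  // "45 min"
DeployWebUI.formatDuration(90 * 60 * 1000L)  // "1.5 h"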
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
new file mode 100644
index 0000000000..a342dd724a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.client
+
+import java.util.concurrent.TimeoutException
+
+import akka.actor._
+import akka.actor.Terminated
+import akka.pattern.ask
+import akka.util.Duration
+import akka.remote.RemoteClientDisconnected
+import akka.remote.RemoteClientLifeCycleEvent
+import akka.remote.RemoteClientShutdown
+import akka.dispatch.Await
+
+import org.apache.spark.Logging
+import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.Master
+
+
+/**
+ * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
+ * and a listener for cluster events, and calls back the listener when various events occur.
+ */
+private[spark] class Client(
+ actorSystem: ActorSystem,
+ masterUrl: String,
+ appDescription: ApplicationDescription,
+ listener: ClientListener)
+ extends Logging {
+
+ var actor: ActorRef = null
+ var appId: String = null
+
+ class ClientActor extends Actor with Logging {
+ var master: ActorRef = null
+ var masterAddress: Address = null
+ var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
+
+ override def preStart() {
+ logInfo("Connecting to master " + masterUrl)
+ try {
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ masterAddress = master.path.address
+ master ! RegisterApplication(appDescription)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to master", e)
+ markDisconnected()
+ context.stop(self)
+ }
+ }
+
+ override def receive = {
+ case RegisteredApplication(appId_) =>
+ appId = appId_
+ listener.connected(appId)
+
+ case ApplicationRemoved(message) =>
+ logError("Master removed our application: %s; stopping client".format(message))
+ markDisconnected()
+ context.stop(self)
+
+ case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
+ val fullId = appId + "/" + id
+ logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
+ listener.executorAdded(fullId, workerId, hostPort, cores, memory)
+
+ case ExecutorUpdated(id, state, message, exitStatus) =>
+ val fullId = appId + "/" + id
+ val messageText = message.map(s => " (" + s + ")").getOrElse("")
+ logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText))
+ if (ExecutorState.isFinished(state)) {
+ listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
+ }
+
+ case Terminated(actor_) if actor_ == master =>
+ logError("Connection to master failed; stopping client")
+ markDisconnected()
+ context.stop(self)
+
+ case RemoteClientDisconnected(transport, address) if address == masterAddress =>
+ logError("Connection to master failed; stopping client")
+ markDisconnected()
+ context.stop(self)
+
+ case RemoteClientShutdown(transport, address) if address == masterAddress =>
+ logError("Connection to master failed; stopping client")
+ markDisconnected()
+ context.stop(self)
+
+ case StopClient =>
+ markDisconnected()
+ sender ! true
+ context.stop(self)
+ }
+
+ /**
+ * Notify the listener that we disconnected, if we haven't already done so.
+ */
+ def markDisconnected() {
+ if (!alreadyDisconnected) {
+ listener.disconnected()
+ alreadyDisconnected = true
+ }
+ }
+ }
+
+ def start() {
+ // Just launch an actor; it will call back into the listener.
+ actor = actorSystem.actorOf(Props(new ClientActor))
+ }
+
+ def stop() {
+ if (actor != null) {
+ try {
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val future = actor.ask(StopClient)(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: TimeoutException =>
+ logInfo("Stop request to Master timed out; it may already be shut down.")
+ }
+ actor = null
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
new file mode 100644
index 0000000000..4605368c11
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.client
+
+/**
+ * Callbacks invoked by the deploy client when various events happen. There are currently four events:
+ * connecting to the cluster, disconnecting, being given an executor, and having an executor
+ * removed (either due to failure or due to revocation).
+ *
+ * Users of this API should *not* block inside the callback methods.
+ */
+private[spark] trait ClientListener {
+ def connected(appId: String): Unit
+
+ def disconnected(): Unit
+
+ def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
+
+ def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
new file mode 100644
index 0000000000..0322029fbd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.client
+
+import org.apache.spark.util.AkkaUtils
+import org.apache.spark.{Logging, Utils}
+import org.apache.spark.deploy.{Command, ApplicationDescription}
+
+private[spark] object TestClient {
+
+ class TestListener extends ClientListener with Logging {
+ def connected(id: String) {
+ logInfo("Connected to master, got app ID " + id)
+ }
+
+ def disconnected() {
+ logInfo("Disconnected from master")
+ System.exit(0)
+ }
+
+ def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
+
+ def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
+ }
+
+ def main(args: Array[String]) {
+ val url = args(0)
+ val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
+ val desc = new ApplicationDescription(
+ "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
+ val listener = new TestListener
+ val client = new Client(actorSystem, url, desc, listener)
+ client.start()
+ actorSystem.awaitTermination()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
new file mode 100644
index 0000000000..c5ac45c673
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/client/TestExecutor.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.client
+
+private[spark] object TestExecutor {
+ def main(args: Array[String]) {
+ println("Hello world!")
+ while (true) {
+ Thread.sleep(1000)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
new file mode 100644
index 0000000000..bd5327627a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.deploy.ApplicationDescription
+import java.util.Date
+import akka.actor.ActorRef
+import scala.collection.mutable
+
+private[spark] class ApplicationInfo(
+ val startTime: Long,
+ val id: String,
+ val desc: ApplicationDescription,
+ val submitDate: Date,
+ val driver: ActorRef,
+ val appUiUrl: String)
+{
+ var state = ApplicationState.WAITING
+ var executors = new mutable.HashMap[Int, ExecutorInfo]
+ var coresGranted = 0
+ var endTime = -1L
+ val appSource = new ApplicationSource(this)
+
+ private var nextExecutorId = 0
+
+ def newExecutorId(): Int = {
+ val id = nextExecutorId
+ nextExecutorId += 1
+ id
+ }
+
+ def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
+ val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
+ executors(exec.id) = exec
+ coresGranted += cores
+ exec
+ }
+
+ def removeExecutor(exec: ExecutorInfo) {
+ if (executors.contains(exec.id)) {
+ executors -= exec.id
+ coresGranted -= exec.cores
+ }
+ }
+
+ def coresLeft: Int = desc.maxCores - coresGranted
+
+ private var _retryCount = 0
+
+ def retryCount = _retryCount
+
+ def incrementRetryCount = {
+ _retryCount += 1
+ _retryCount
+ }
+
+ def markFinished(endState: ApplicationState.Value) {
+ state = endState
+ endTime = System.currentTimeMillis()
+ }
+
+ def duration: Long = {
+ if (endTime != -1) {
+ endTime - startTime
+ } else {
+ System.currentTimeMillis() - startTime
+ }
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
new file mode 100644
index 0000000000..2d75ad5a2c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala
@@ -0,0 +1,24 @@
+package org.apache.spark.deploy.master
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.spark.metrics.source.Source
+
+class ApplicationSource(val application: ApplicationInfo) extends Source {
+ val metricRegistry = new MetricRegistry()
+ val sourceName = "%s.%s.%s".format("application", application.desc.name,
+ System.currentTimeMillis())
+
+ metricRegistry.register(MetricRegistry.name("status"), new Gauge[String] {
+ override def getValue: String = application.state.toString
+ })
+
+ metricRegistry.register(MetricRegistry.name("runtime_ms"), new Gauge[Long] {
+ override def getValue: Long = application.duration
+ })
+
+ metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] {
+ override def getValue: Int = application.coresGranted
+ })
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
new file mode 100644
index 0000000000..7e804223cf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object ApplicationState
+ extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
+
+ type ApplicationState = Value
+
+ val WAITING, RUNNING, FINISHED, FAILED = Value
+
+ val MAX_NUM_RETRY = 10
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
new file mode 100644
index 0000000000..cf384a985e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.deploy.ExecutorState
+
+private[spark] class ExecutorInfo(
+ val id: Int,
+ val application: ApplicationInfo,
+ val worker: WorkerInfo,
+ val cores: Int,
+ val memory: Int) {
+
+ var state = ExecutorState.LAUNCHING
+
+ def fullId: String = application.id + "/" + id
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
new file mode 100644
index 0000000000..869b2b2646
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -0,0 +1,386 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
+import akka.actor.Terminated
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.util.duration._
+
+import org.apache.spark.{Logging, SparkException, Utils}
+import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
+import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.ui.MasterWebUI
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.util.AkkaUtils
+
+
+private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+ val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
+ val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
+ val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
+ val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
+
+ var nextAppNumber = 0
+ val workers = new HashSet[WorkerInfo]
+ val idToWorker = new HashMap[String, WorkerInfo]
+ val actorToWorker = new HashMap[ActorRef, WorkerInfo]
+ val addressToWorker = new HashMap[Address, WorkerInfo]
+
+ val apps = new HashSet[ApplicationInfo]
+ val idToApp = new HashMap[String, ApplicationInfo]
+ val actorToApp = new HashMap[ActorRef, ApplicationInfo]
+ val addressToApp = new HashMap[Address, ApplicationInfo]
+
+ val waitingApps = new ArrayBuffer[ApplicationInfo]
+ val completedApps = new ArrayBuffer[ApplicationInfo]
+
+ var firstApp: Option[ApplicationInfo] = None
+
+ Utils.checkHost(host, "Expected hostname")
+
+ val masterMetricsSystem = MetricsSystem.createMetricsSystem("master")
+ val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications")
+ val masterSource = new MasterSource(this)
+
+ val webUi = new MasterWebUI(this, webUiPort)
+
+ val masterPublicAddress = {
+ val envVar = System.getenv("SPARK_PUBLIC_DNS")
+ if (envVar != null) envVar else host
+ }
+
+ // As a temporary workaround until we have better ways of configuring memory, we allow users to
+ // set a flag that will perform round-robin scheduling across the nodes (spreading out each app
+ // among all the nodes) instead of trying to consolidate each app onto a small number of nodes.
+ val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
+
+ override def preStart() {
+ logInfo("Starting Spark master at spark://" + host + ":" + port)
+ // Listen for remote client disconnection events, since they don't go through Akka's watch()
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ webUi.start()
+ context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
+
+ masterMetricsSystem.registerSource(masterSource)
+ masterMetricsSystem.start()
+ applicationMetricsSystem.start()
+ }
+
+ override def postStop() {
+ webUi.stop()
+ masterMetricsSystem.stop()
+ applicationMetricsSystem.stop()
+ }
+
+ override def receive = {
+ case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
+ logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
+ host, workerPort, cores, Utils.megabytesToString(memory)))
+ if (idToWorker.contains(id)) {
+ sender ! RegisterWorkerFailed("Duplicate worker ID")
+ } else {
+ addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
+ context.watch(sender) // This doesn't work with remote actors but helps for testing
+ sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
+ schedule()
+ }
+ }
+
+ case RegisterApplication(description) => {
+ logInfo("Registering app " + description.name)
+ val app = addApplication(description, sender)
+ logInfo("Registered app " + description.name + " with ID " + app.id)
+ waitingApps += app
+ context.watch(sender) // This doesn't work with remote actors but helps for testing
+ sender ! RegisteredApplication(app.id)
+ schedule()
+ }
+
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
+ val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId))
+ execOption match {
+ case Some(exec) => {
+ exec.state = state
+ exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus)
+ if (ExecutorState.isFinished(state)) {
+ val appInfo = idToApp(appId)
+ // Remove this executor from the worker and app
+ logInfo("Removing executor " + exec.fullId + " because it is " + state)
+ appInfo.removeExecutor(exec)
+ exec.worker.removeExecutor(exec)
+
+ // Only retry a certain number of times so we don't go into an infinite loop.
+ if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) {
+ schedule()
+ } else {
+ logError("Application %s with ID %s failed %d times, removing it".format(
+ appInfo.desc.name, appInfo.id, appInfo.retryCount))
+ removeApplication(appInfo, ApplicationState.FAILED)
+ }
+ }
+ }
+ case None =>
+ logWarning("Got status update for unknown executor " + appId + "/" + execId)
+ }
+ }
+
+ case Heartbeat(workerId) => {
+ idToWorker.get(workerId) match {
+ case Some(workerInfo) =>
+ workerInfo.lastHeartbeat = System.currentTimeMillis()
+ case None =>
+ logWarning("Got heartbeat from unregistered worker " + workerId)
+ }
+ }
+
+ case Terminated(actor) => {
+ // The disconnected actor could've been either a worker or an app; remove whichever of
+ // those we have an entry for in the corresponding actor hashmap
+ actorToWorker.get(actor).foreach(removeWorker)
+ actorToApp.get(actor).foreach(finishApplication)
+ }
+
+ case RemoteClientDisconnected(transport, address) => {
+ // The disconnected client could've been either a worker or an app; remove whichever it was
+ addressToWorker.get(address).foreach(removeWorker)
+ addressToApp.get(address).foreach(finishApplication)
+ }
+
+ case RemoteClientShutdown(transport, address) => {
+ // The disconnected client could've been either a worker or an app; remove whichever it was
+ addressToWorker.get(address).foreach(removeWorker)
+ addressToApp.get(address).foreach(finishApplication)
+ }
+
+ case RequestMasterState => {
+ sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
+ }
+
+ case CheckForWorkerTimeOut => {
+ timeOutDeadWorkers()
+ }
+ }
+
+ /**
+ * Can an app use the given worker? True if the worker has enough memory and we haven't already
+ * launched an executor for the app on it (right now the standalone backend doesn't like having
+ * two executors on the same worker).
+ */
+ def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = {
+ worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app)
+ }
+
+ /**
+ * Schedule the currently available resources among waiting apps. This method will be called
+ * every time a new app joins or resource availability changes.
+ */
+ def schedule() {
+ // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
+ // in the queue, then the second app, etc.
+ if (spreadOutApps) {
+ // Try to spread out each app among all the nodes, until it has all its cores
+ for (app <- waitingApps if app.coresLeft > 0) {
+ val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE)
+ .filter(canUse(app, _)).sortBy(_.coresFree).reverse
+ val numUsable = usableWorkers.length
+ val assigned = new Array[Int](numUsable) // Number of cores to give on each node
+ var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum)
+ var pos = 0
+ while (toAssign > 0) {
+ if (usableWorkers(pos).coresFree - assigned(pos) > 0) {
+ toAssign -= 1
+ assigned(pos) += 1
+ }
+ pos = (pos + 1) % numUsable
+ }
+ // Now that we've decided how many cores to give on each node, let's actually give them
+ for (pos <- 0 until numUsable) {
+ if (assigned(pos) > 0) {
+ val exec = app.addExecutor(usableWorkers(pos), assigned(pos))
+ launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
+ }
+ }
+ }
+ } else {
+ // Pack each app into as few nodes as possible until we've assigned all its cores
+ for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) {
+ for (app <- waitingApps if app.coresLeft > 0) {
+ if (canUse(app, worker)) {
+ val coresToUse = math.min(worker.coresFree, app.coresLeft)
+ if (coresToUse > 0) {
+ val exec = app.addExecutor(worker, coresToUse)
+ launchExecutor(worker, exec, app.desc.sparkHome)
+ app.state = ApplicationState.RUNNING
+ }
+ }
+ }
+ }
+ }
+ }
+
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
+ logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
+ worker.addExecutor(exec)
+ worker.actor ! LaunchExecutor(
+ exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
+ exec.application.driver ! ExecutorAdded(
+ exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
+ }
+
+ def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
+ publicAddress: String): WorkerInfo = {
+ // There may be one or more refs to dead workers on this same node (with different IDs);
+ // remove them.
+ workers.filter { w =>
+ (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)
+ }.foreach { w =>
+ workers -= w
+ }
+ val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ workers += worker
+ idToWorker(worker.id) = worker
+ actorToWorker(sender) = worker
+ addressToWorker(sender.path.address) = worker
+ worker
+ }
+
+ def removeWorker(worker: WorkerInfo) {
+ logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port)
+ worker.setState(WorkerState.DEAD)
+ idToWorker -= worker.id
+ actorToWorker -= worker.actor
+ addressToWorker -= worker.actor.path.address
+ for (exec <- worker.executors.values) {
+ logInfo("Telling app of lost executor: " + exec.id)
+ exec.application.driver ! ExecutorUpdated(
+ exec.id, ExecutorState.LOST, Some("worker lost"), None)
+ exec.application.removeExecutor(exec)
+ }
+ }
+
+ def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
+ val now = System.currentTimeMillis()
+ val date = new Date(now)
+ val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
+ applicationMetricsSystem.registerSource(app.appSource)
+ apps += app
+ idToApp(app.id) = app
+ actorToApp(driver) = app
+ addressToApp(driver.path.address) = app
+ if (firstApp == None) {
+ firstApp = Some(app)
+ }
+ val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
+ if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
+ logWarning("Could not find any workers with enough memory for " + app.id)
+ }
+ app
+ }
+
+ def finishApplication(app: ApplicationInfo) {
+ removeApplication(app, ApplicationState.FINISHED)
+ }
+
+ def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) {
+ if (apps.contains(app)) {
+ logInfo("Removing app " + app.id)
+ apps -= app
+ idToApp -= app.id
+ actorToApp -= app.driver
+ addressToApp -= app.driver.path.address
+ if (completedApps.size >= RETAINED_APPLICATIONS) {
+ val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1)
+ completedApps.take(toRemove).foreach( a => {
+ applicationMetricsSystem.removeSource(a.appSource)
+ })
+ completedApps.trimStart(toRemove)
+ }
+ completedApps += app // Remember it in our history
+ waitingApps -= app
+ for (exec <- app.executors.values) {
+ exec.worker.removeExecutor(exec)
+ exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
+ exec.state = ExecutorState.KILLED
+ }
+ app.markFinished(state)
+ if (state != ApplicationState.FINISHED) {
+ app.driver ! ApplicationRemoved(state.toString)
+ }
+ schedule()
+ }
+ }
+
+ /** Generate a new app ID given an app's submission date */
+ def newApplicationId(submitDate: Date): String = {
+ val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber)
+ nextAppNumber += 1
+ appId
+ }
+
+ /** Check for, and remove, any timed-out workers */
+ def timeOutDeadWorkers() {
+ // Copy the workers into an array so we don't modify the hashset while iterating through it
+ val currentTime = System.currentTimeMillis()
+ val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray
+ for (worker <- toRemove) {
+ if (worker.state != WorkerState.DEAD) {
+ logWarning("Removing %s because we got no heartbeat in %d seconds".format(
+ worker.id, WORKER_TIMEOUT/1000))
+ removeWorker(worker)
+ } else {
+ if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT))
+ workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it
+ }
+ }
+ }
+}
+
+private[spark] object Master {
+ private val systemName = "sparkMaster"
+ private val actorName = "Master"
+ private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+
+ def main(argStrings: Array[String]) {
+ val args = new MasterArguments(argStrings)
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
+ actorSystem.awaitTermination()
+ }
+
+ /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:port`. */
+ def toAkkaUrl(sparkUrl: String): String = {
+ sparkUrl match {
+ case sparkUrlRegex(host, port) =>
+ "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ case _ =>
+ throw new SparkException("Invalid master URL: " + sparkUrl)
+ }
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int) = {
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
+ (actorSystem, boundPort)
+ }
+}
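To make the spread-out branch of schedule() above easier to follow, here is a self-contained sketch of the same round-robin core assignment with made-up numbers: three usable workers with 4, 2 and 1 free cores, and an app that still wants 5 cores.

// Standalone sketch of the assignment loop used when spark.deploy.spreadOut is true.
val coresFree = Array(4, 2, 1)               // free cores per usable worker, sorted descending
val assigned = new Array[Int](coresFree.length)
var toAssign = math.min(5, coresFree.sum)    // the app still wants 5 cores
var pos = 0
while (toAssign > 0) {
  if (coresFree(pos) - assigned(pos) > 0) {  // this worker can still take one more core
    toAssign -= 1
    assigned(pos) += 1
  }
  pos = (pos + 1) % coresFree.length
}
// assigned == Array(2, 2, 1): one core per worker per pass until all 5 cores are placed.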
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
new file mode 100644
index 0000000000..c86cca278d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.util.IntParam
+import org.apache.spark.Utils
+
+/**
+ * Command-line parser for the master.
+ */
+private[spark] class MasterArguments(args: Array[String]) {
+ var host = Utils.localHostName()
+ var port = 7077
+ var webUiPort = 8080
+
+ // Check for settings in environment variables
+ if (System.getenv("SPARK_MASTER_HOST") != null) {
+ host = System.getenv("SPARK_MASTER_HOST")
+ }
+ if (System.getenv("SPARK_MASTER_PORT") != null) {
+ port = System.getenv("SPARK_MASTER_PORT").toInt
+ }
+ if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) {
+ webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt
+ }
+ if (System.getProperty("master.ui.port") != null) {
+ webUiPort = System.getProperty("master.ui.port").toInt
+ }
+
+ parse(args.toList)
+
+ def parse(args: List[String]): Unit = args match {
+ case ("--ip" | "-i") :: value :: tail =>
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--port" | "-p") :: IntParam(value) :: tail =>
+ port = value
+ parse(tail)
+
+ case "--webui-port" :: IntParam(value) :: tail =>
+ webUiPort = value
+ parse(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case Nil => {}
+
+ case _ =>
+ printUsageAndExit(1)
+ }
+
+ /**
+ * Print usage and exit JVM with the given exit code.
+ */
+ def printUsageAndExit(exitCode: Int) {
+ System.err.println(
+ "Usage: Master [options]\n" +
+ "\n" +
+ "Options:\n" +
+ " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
+ " -p PORT, --port PORT Port to listen on (default: 7077)\n" +
+ " --webui-port PORT Port for web UI (default: 8080)")
+ System.exit(exitCode)
+ }
+}
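A small sketch of the parser in use. The hostname and port are made up, the snippet assumes the SPARK_MASTER_* environment variables are unset, and MasterArguments is private[spark], so this would live inside the org.apache.spark package.

import org.apache.spark.deploy.master.MasterArguments

val parsed = new MasterArguments(Array("--host", "master.example.com", "--webui-port", "8090"))
// parsed.host == "master.example.com", parsed.port == 7077 (default), parsed.webUiPort == 8090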
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
new file mode 100644
index 0000000000..8dd0a42f71
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala
@@ -0,0 +1,25 @@
+package org.apache.spark.deploy.master
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.spark.metrics.source.Source
+
+private[spark] class MasterSource(val master: Master) extends Source {
+ val metricRegistry = new MetricRegistry()
+ val sourceName = "master"
+
+ // Gauge for the number of workers in the cluster
+ metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] {
+ override def getValue: Int = master.workers.size
+ })
+
+ // Gauge for the number of applications in the cluster
+ metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] {
+ override def getValue: Int = master.apps.size
+ })
+
+ // Gauge for the number of waiting applications in the cluster
+ metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] {
+ override def getValue: Int = master.waitingApps.size
+ })
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
new file mode 100644
index 0000000000..285e07a823
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.actor.ActorRef
+import scala.collection.mutable
+import org.apache.spark.Utils
+
+private[spark] class WorkerInfo(
+ val id: String,
+ val host: String,
+ val port: Int,
+ val cores: Int,
+ val memory: Int,
+ val actor: ActorRef,
+ val webUiPort: Int,
+ val publicAddress: String) {
+
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
+ var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
+ var state: WorkerState.Value = WorkerState.ALIVE
+ var coresUsed = 0
+ var memoryUsed = 0
+
+ var lastHeartbeat = System.currentTimeMillis()
+
+ def coresFree: Int = cores - coresUsed
+ def memoryFree: Int = memory - memoryUsed
+
+ def hostPort: String = {
+ assert (port > 0)
+ host + ":" + port
+ }
+
+ def addExecutor(exec: ExecutorInfo) {
+ executors(exec.fullId) = exec
+ coresUsed += exec.cores
+ memoryUsed += exec.memory
+ }
+
+ def removeExecutor(exec: ExecutorInfo) {
+ if (executors.contains(exec.fullId)) {
+ executors -= exec.fullId
+ coresUsed -= exec.cores
+ memoryUsed -= exec.memory
+ }
+ }
+
+ def hasExecutor(app: ApplicationInfo): Boolean = {
+ executors.values.exists(_.application == app)
+ }
+
+ def webUiAddress : String = {
+ "http://" + this.publicAddress + ":" + this.webUiPort
+ }
+
+ def setState(state: WorkerState.Value) = {
+ this.state = state
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
new file mode 100644
index 0000000000..b5ee6dca79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+private[spark] object WorkerState extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED") {
+ type WorkerState = Value
+
+ val ALIVE, DEAD, DECOMMISSIONED = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
new file mode 100644
index 0000000000..6435c7f917
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master.ui
+
+import scala.xml.Node
+
+import akka.dispatch.Await
+import akka.pattern.ask
+import akka.util.duration._
+
+import javax.servlet.http.HttpServletRequest
+
+import net.liftweb.json.JsonAST.JValue
+
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.master.ExecutorInfo
+import org.apache.spark.ui.UIUtils
+import org.apache.spark.Utils
+
+private[spark] class ApplicationPage(parent: MasterWebUI) {
+ val master = parent.masterActorRef
+ implicit val timeout = parent.timeout
+
+ /** Details for a particular application, rendered as JSON */
+ def renderJson(request: HttpServletRequest): JValue = {
+ val appId = request.getParameter("appId")
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, 30 seconds)
+ val app = state.activeApps.find(_.id == appId).getOrElse({
+ state.completedApps.find(_.id == appId).getOrElse(null)
+ })
+ JsonProtocol.writeApplicationInfo(app)
+ }
+
+ /** Executor details for a particular application */
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val appId = request.getParameter("appId")
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, 30 seconds)
+ val app = state.activeApps.find(_.id == appId).getOrElse({
+ state.completedApps.find(_.id == appId).getOrElse(null)
+ })
+
+ val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs")
+ val executors = app.executors.values.toSeq
+ val executorTable = UIUtils.listingTable(executorHeaders, executorRow, executors)
+
+ val content =
+ <div class="row-fluid">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>ID:</strong> {app.id}</li>
+ <li><strong>Name:</strong> {app.desc.name}</li>
+ <li><strong>User:</strong> {app.desc.user}</li>
+ <li><strong>Cores:</strong>
+ {
+ if (app.desc.maxCores == Integer.MAX_VALUE) {
+ "Unlimited (%s granted)".format(app.coresGranted)
+ } else {
+ "%s (%s granted, %s left)".format(
+ app.desc.maxCores, app.coresGranted, app.coresLeft)
+ }
+ }
+ </li>
+ <li>
+ <strong>Executor Memory:</strong>
+ {Utils.megabytesToString(app.desc.memoryPerSlave)}
+ </li>
+ <li><strong>Submit Date:</strong> {app.submitDate}</li>
+ <li><strong>State:</strong> {app.state}</li>
+ <li><strong><a href={app.appUiUrl}>Application Detail UI</a></strong></li>
+ </ul>
+ </div>
+ </div>
+
+ <div class="row-fluid"> <!-- Executors -->
+ <div class="span12">
+ <h4> Executor Summary </h4>
+ {executorTable}
+ </div>
+ </div>;
+ UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
+ }
+
+ def executorRow(executor: ExecutorInfo): Seq[Node] = {
+ <tr>
+ <td>{executor.id}</td>
+ <td>
+ <a href={executor.worker.webUiAddress}>{executor.worker.id}</a>
+ </td>
+ <td>{executor.cores}</td>
+ <td>{executor.memory}</td>
+ <td>{executor.state}</td>
+ <td>
+ <a href={"%s/logPage?appId=%s&executorId=%s&logType=stdout"
+ .format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stdout</a>
+ <a href={"%s/logPage?appId=%s&executorId=%s&logType=stderr"
+ .format(executor.worker.webUiAddress, executor.application.id, executor.id)}>stderr</a>
+ </td>
+ </tr>
+ }
+}
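
Both render methods follow the same shape: ask the master actor for its state, block up to 30 seconds for the reply, then look the application up among active apps before falling back to completed ones. A minimal sketch of that ask-then-await flow using the standard library's Future and Await rather than the Akka 2.0-era akka.dispatch.Await used above; MasterStateResponse here is a simplified stand-in:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    object AskSketch {
      // Simplified stand-in for the master's reply message.
      case class MasterStateResponse(activeAppIds: Seq[String], completedAppIds: Seq[String])

      // Stand-in for (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
      def requestMasterState(): Future[MasterStateResponse] =
        Future(MasterStateResponse(Seq("app-20130901-0001"), Seq("app-20130831-0007")))

      def main(args: Array[String]): Unit = {
        val state = Await.result(requestMasterState(), 30.seconds)  // block for the reply
        val appId = "app-20130831-0007"
        // Prefer a running application, fall back to a completed one.
        val app = state.activeAppIds.find(_ == appId).orElse(state.completedAppIds.find(_ == appId))
        println(app)  // Some(app-20130831-0007)
      }
    }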
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
new file mode 100644
index 0000000000..58d3863009
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import akka.dispatch.Await
+import akka.pattern.ask
+import akka.util.duration._
+
+import net.liftweb.json.JsonAST.JValue
+
+import org.apache.spark.Utils
+import org.apache.spark.deploy.DeployWebUI
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
+import org.apache.spark.ui.UIUtils
+
+private[spark] class IndexPage(parent: MasterWebUI) {
+ val master = parent.masterActorRef
+ implicit val timeout = parent.timeout
+
+ def renderJson(request: HttpServletRequest): JValue = {
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, 30 seconds)
+ JsonProtocol.writeMasterState(state)
+ }
+
+ /** Index view listing applications and executors */
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
+ val state = Await.result(stateFuture, 30 seconds)
+
+ val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
+ val workers = state.workers.sortBy(_.id)
+ val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
+
+ val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User",
+ "State", "Duration")
+ val activeApps = state.activeApps.sortBy(_.startTime).reverse
+ val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
+ val completedApps = state.completedApps.sortBy(_.endTime).reverse
+ val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
+
+ val content =
+ <div class="row-fluid">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>URL:</strong> {state.uri}</li>
+ <li><strong>Workers:</strong> {state.workers.size}</li>
+ <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total,
+ {state.workers.map(_.coresUsed).sum} Used</li>
+ <li><strong>Memory:</strong>
+ {Utils.megabytesToString(state.workers.map(_.memory).sum)} Total,
+ {Utils.megabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li>
+ <li><strong>Applications:</strong>
+ {state.activeApps.size} Running,
+ {state.completedApps.size} Completed </li>
+ </ul>
+ </div>
+ </div>
+
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Workers </h4>
+ {workerTable}
+ </div>
+ </div>
+
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Running Applications </h4>
+
+ {activeAppsTable}
+ </div>
+ </div>
+
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Completed Applications </h4>
+ {completedAppsTable}
+ </div>
+ </div>;
+ UIUtils.basicSparkPage(content, "Spark Master at " + state.uri)
+ }
+
+ def workerRow(worker: WorkerInfo): Seq[Node] = {
+ <tr>
+ <td>
+ <a href={worker.webUiAddress}>{worker.id}</a>
+ </td>
+ <td>{worker.host}:{worker.port}</td>
+ <td>{worker.state}</td>
+ <td>{worker.cores} ({worker.coresUsed} Used)</td>
+ <td sorttable_customkey={"%s.%s".format(worker.memory, worker.memoryUsed)}>
+ {Utils.megabytesToString(worker.memory)}
+ ({Utils.megabytesToString(worker.memoryUsed)} Used)
+ </td>
+ </tr>
+ }
+
+
+ def appRow(app: ApplicationInfo): Seq[Node] = {
+ <tr>
+ <td>
+ <a href={"app?appId=" + app.id}>{app.id}</a>
+ </td>
+ <td>
+ <a href={app.appUiUrl}>{app.desc.name}</a>
+ </td>
+ <td>
+ {app.coresGranted}
+ </td>
+ <td sorttable_customkey={app.desc.memoryPerSlave.toString}>
+ {Utils.megabytesToString(app.desc.memoryPerSlave)}
+ </td>
+ <td>{DeployWebUI.formatDate(app.submitDate)}</td>
+ <td>{app.desc.user}</td>
+ <td>{app.state.toString}</td>
+ <td>{DeployWebUI.formatDuration(app.duration)}</td>
+ </tr>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
new file mode 100644
index 0000000000..47b1e521f5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master.ui
+
+import akka.util.Duration
+
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.{Handler, Server}
+
+import org.apache.spark.{Logging, Utils}
+import org.apache.spark.deploy.master.Master
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.ui.JettyUtils._
+
+/**
+ * Web UI server for the standalone master.
+ */
+private[spark]
+class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
+ implicit val timeout = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val host = Utils.localHostName()
+ val port = requestedPort
+
+ val masterActorRef = master.self
+
+ var server: Option[Server] = None
+ var boundPort: Option[Int] = None
+
+ val applicationPage = new ApplicationPage(this)
+ val indexPage = new IndexPage(this)
+
+ def start() {
+ try {
+ val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers)
+ server = Some(srv)
+ boundPort = Some(bPort)
+ logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get))
+ } catch {
+ case e: Exception =>
+ logError("Failed to create Master JettyUtils", e)
+ System.exit(1)
+ }
+ }
+
+ val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++
+ master.applicationMetricsSystem.getServletHandlers
+
+ val handlers = metricsHandlers ++ Array[(String, Handler)](
+ ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)),
+ ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)),
+ ("/app", (request: HttpServletRequest) => applicationPage.render(request)),
+ ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
+ ("*", (request: HttpServletRequest) => indexPage.render(request))
+ )
+
+ def stop() {
+ server.foreach(_.stop())
+ }
+}
+
+private[spark] object MasterWebUI {
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+}
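
The handlers value above is an ordered list of (path, handler) pairs that JettyUtils mounts; the catch-all "*" entry is listed last so the more specific routes win. A toy sketch of prefix-based dispatch over such a table (plain functions stand in for Jetty handlers, and the exact matching rules of JettyUtils are an assumption here):

    object RouteSketch {
      type Handler = String => String   // request path => response body (stand-in for a Jetty Handler)

      val routes: Seq[(String, Handler)] = Seq(
        ("/json", _ => "master state as JSON"),
        ("/app",  path => s"application page for $path"),
        ("*",     _ => "index page")
      )

      // First matching prefix wins; "*" matches everything, so it is listed last.
      def dispatch(path: String): String =
        routes.collectFirst { case (p, h) if p == "*" || path.startsWith(p) => h(path) }.get

      def main(args: Array[String]): Unit = {
        println(dispatch("/json"))          // master state as JSON
        println(dispatch("/app?appId=x"))   // application page for /app?appId=x
        println(dispatch("/anything"))      // index page
      }
    }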
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
new file mode 100644
index 0000000000..01ce4a6dea
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import java.io._
+import java.lang.System.getenv
+
+import akka.actor.ActorRef
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+
+import org.apache.spark.{Utils, Logging}
+import org.apache.spark.deploy.{ExecutorState, ApplicationDescription}
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+
+/**
+ * Manages the execution of one executor process.
+ */
+private[spark] class ExecutorRunner(
+ val appId: String,
+ val execId: Int,
+ val appDesc: ApplicationDescription,
+ val cores: Int,
+ val memory: Int,
+ val worker: ActorRef,
+ val workerId: String,
+ val host: String,
+ val sparkHome: File,
+ val workDir: File)
+ extends Logging {
+
+ val fullId = appId + "/" + execId
+ var workerThread: Thread = null
+ var process: Process = null
+ var shutdownHook: Thread = null
+
+ private def getAppEnv(key: String): Option[String] =
+ appDesc.command.environment.get(key).orElse(Option(getenv(key)))
+
+ def start() {
+ workerThread = new Thread("ExecutorRunner for " + fullId) {
+ override def run() { fetchAndRunExecutor() }
+ }
+ workerThread.start()
+
+ // Shutdown hook that kills the child executor process on shutdown.
+ shutdownHook = new Thread() {
+ override def run() {
+ if (process != null) {
+ logInfo("Shutdown hook killing child process.")
+ process.destroy()
+ process.waitFor()
+ }
+ }
+ }
+ Runtime.getRuntime.addShutdownHook(shutdownHook)
+ }
+
+ /** Stop this executor runner, including killing the process it launched */
+ def kill() {
+ if (workerThread != null) {
+ workerThread.interrupt()
+ workerThread = null
+ if (process != null) {
+ logInfo("Killing process!")
+ process.destroy()
+ process.waitFor()
+ }
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
+ Runtime.getRuntime.removeShutdownHook(shutdownHook)
+ }
+ }
+
+ /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
+ def substituteVariables(argument: String): String = argument match {
+ case "{{EXECUTOR_ID}}" => execId.toString
+ case "{{HOSTNAME}}" => host
+ case "{{CORES}}" => cores.toString
+ case other => other
+ }
+
+ def buildCommandSeq(): Seq[String] = {
+ val command = appDesc.command
+ val runner = getAppEnv("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
+ // SPARK-698: do not call the run.cmd script, as process.destroy()
+ // fails to kill a process tree on Windows
+ Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
+ command.arguments.map(substituteVariables)
+ }
+
+ /**
+ * Attention: this must always be kept in sync with the environment variables set in the run
+ * scripts and with the way JAVA_OPTS are assembled there.
+ */
+ def buildJavaOpts(): Seq[String] = {
+ val libraryOpts = getAppEnv("SPARK_LIBRARY_PATH")
+ .map(p => List("-Djava.library.path=" + p))
+ .getOrElse(Nil)
+ val workerLocalOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
+ val userOpts = getAppEnv("SPARK_JAVA_OPTS").map(Utils.splitCommandString).getOrElse(Nil)
+ val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
+
+ // Figure out our classpath with the external compute-classpath script
+ val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
+ val classPath = Utils.executeAndGetOutput(
+ Seq(sparkHome + "/bin/compute-classpath" + ext),
+ extraEnvironment=appDesc.command.environment)
+
+ Seq("-cp", classPath) ++ libraryOpts ++ workerLocalOpts ++ userOpts ++ memoryOpts
+ }
+
+ /** Spawn a thread that will redirect a given stream to a file */
+ def redirectStream(in: InputStream, file: File) {
+ val out = new FileOutputStream(file, true)
+ new Thread("redirect output to " + file) {
+ override def run() {
+ try {
+ Utils.copyStream(in, out, true)
+ } catch {
+ case e: IOException =>
+ logInfo("Redirection to " + file + " closed: " + e.getMessage)
+ }
+ }
+ }.start()
+ }
+
+ /**
+ * Download and run the executor described in our ApplicationDescription
+ */
+ def fetchAndRunExecutor() {
+ try {
+ // Create the executor's working directory
+ val executorDir = new File(workDir, appId + "/" + execId)
+ if (!executorDir.mkdirs()) {
+ throw new IOException("Failed to create directory " + executorDir)
+ }
+
+ // Launch the process
+ val command = buildCommandSeq()
+ logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
+ val builder = new ProcessBuilder(command: _*).directory(executorDir)
+ val env = builder.environment()
+ for ((key, value) <- appDesc.command.environment) {
+ env.put(key, value)
+ }
+ // In case we are running this from within the Spark Shell, avoid creating a "scala"
+ // parent process for the executor command
+ env.put("SPARK_LAUNCH_WITH_SCALA", "0")
+ process = builder.start()
+
+ val header = "Spark Executor Command: %s\n%s\n\n".format(
+ command.mkString("\"", "\" \"", "\""), "=" * 40)
+
+ // Redirect its stdout and stderr to files
+ val stdout = new File(executorDir, "stdout")
+ redirectStream(process.getInputStream, stdout)
+
+ val stderr = new File(executorDir, "stderr")
+ Files.write(header, stderr, Charsets.UTF_8)
+ redirectStream(process.getErrorStream, stderr)
+
+ // Wait for it to exit; this is actually a bad thing if it happens, because we expect to run
+ // long-lived processes only. However, in the future, we might restart the executor a few
+ // times on the same machine.
+ val exitCode = process.waitFor()
+ val message = "Command exited with code " + exitCode
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
+ Some(exitCode))
+ } catch {
+ case interrupted: InterruptedException =>
+ logInfo("Runner thread for executor " + fullId + " interrupted")
+
+ case e: Exception => {
+ logError("Error running executor", e)
+ if (process != null) {
+ process.destroy()
+ }
+ val message = e.getClass + ": " + e.getMessage
+ worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
+ }
+ }
+ }
+}
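
buildCommandSeq assembles the java binary, the JVM options, the main class, and the application's arguments, with placeholder tokens such as {{EXECUTOR_ID}} rewritten per executor by substituteVariables. A self-contained sketch of that substitution step, using a hypothetical argument template:

    object SubstituteSketch {
      def substitute(execId: Int, host: String, cores: Int, arg: String): String = arg match {
        case "{{EXECUTOR_ID}}" => execId.toString
        case "{{HOSTNAME}}"    => host
        case "{{CORES}}"       => cores.toString
        case other             => other
      }

      def main(args: Array[String]): Unit = {
        // Hypothetical argument template; the real one comes from the application's Command.
        val template = Seq("--executor-id", "{{EXECUTOR_ID}}", "--hostname", "{{HOSTNAME}}", "--cores", "{{CORES}}")
        val resolved = template.map(arg => substitute(7, "worker-3.example.com", 4, arg))
        println(resolved.mkString(" "))
        // --executor-id 7 --hostname worker-3.example.com --cores 4
      }
    }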
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
new file mode 100644
index 0000000000..86e8e7543b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import java.text.SimpleDateFormat
+import java.util.Date
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
+import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
+import akka.util.duration._
+
+import org.apache.spark.{Logging, Utils}
+import org.apache.spark.deploy.ExecutorState
+import org.apache.spark.deploy.DeployMessages._
+import org.apache.spark.deploy.master.Master
+import org.apache.spark.deploy.worker.ui.WorkerWebUI
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.util.AkkaUtils
+
+
+private[spark] class Worker(
+ host: String,
+ port: Int,
+ webUiPort: Int,
+ cores: Int,
+ memory: Int,
+ masterUrl: String,
+ workDirPath: String = null)
+ extends Actor with Logging {
+
+ Utils.checkHost(host, "Expected hostname")
+ assert (port > 0)
+
+ val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
+
+ // Send a heartbeat every (worker timeout) / 4, converted from seconds to milliseconds
+ val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
+
+ var master: ActorRef = null
+ var masterWebUiUrl : String = ""
+ val workerId = generateWorkerId()
+ var sparkHome: File = null
+ var workDir: File = null
+ val executors = new HashMap[String, ExecutorRunner]
+ val finishedExecutors = new HashMap[String, ExecutorRunner]
+ val publicAddress = {
+ val envVar = System.getenv("SPARK_PUBLIC_DNS")
+ if (envVar != null) envVar else host
+ }
+ var webUi: WorkerWebUI = null
+
+ var coresUsed = 0
+ var memoryUsed = 0
+
+ val metricsSystem = MetricsSystem.createMetricsSystem("worker")
+ val workerSource = new WorkerSource(this)
+
+ def coresFree: Int = cores - coresUsed
+ def memoryFree: Int = memory - memoryUsed
+
+ def createWorkDir() {
+ workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
+ try {
+ // Checking !workDir.exists() && !workDir.mkdirs() sporadically fails for reasons we have not
+ // tracked down, so create the directory unconditionally and then verify that it exists.
+ workDir.mkdirs()
+ if ( !workDir.exists() || !workDir.isDirectory) {
+ logError("Failed to create work directory " + workDir)
+ System.exit(1)
+ }
+ assert (workDir.isDirectory)
+ } catch {
+ case e: Exception =>
+ logError("Failed to create work directory " + workDir, e)
+ System.exit(1)
+ }
+ }
+
+ override def preStart() {
+ logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
+ host, port, cores, Utils.megabytesToString(memory)))
+ sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
+ logInfo("Spark home: " + sparkHome)
+ createWorkDir()
+ webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
+
+ webUi.start()
+ connectToMaster()
+
+ metricsSystem.registerSource(workerSource)
+ metricsSystem.start()
+ }
+
+ def connectToMaster() {
+ logInfo("Connecting to master " + masterUrl)
+ master = context.actorFor(Master.toAkkaUrl(masterUrl))
+ master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(master) // Doesn't work with remote actors, but useful for testing
+ }
+
+ override def receive = {
+ case RegisteredWorker(url) =>
+ masterWebUiUrl = url
+ logInfo("Successfully registered with master")
+ context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
+ master ! Heartbeat(workerId)
+ }
+
+ case RegisterWorkerFailed(message) =>
+ logError("Worker registration failed: " + message)
+ System.exit(1)
+
+ case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
+ logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
+ val manager = new ExecutorRunner(
+ appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
+ executors(appId + "/" + execId) = manager
+ manager.start()
+ coresUsed += cores_
+ memoryUsed += memory_
+ master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
+
+ case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
+ master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
+ val fullId = appId + "/" + execId
+ if (ExecutorState.isFinished(state)) {
+ val executor = executors(fullId)
+ logInfo("Executor " + fullId + " finished with state " + state +
+ message.map(" message " + _).getOrElse("") +
+ exitStatus.map(" exitStatus " + _).getOrElse(""))
+ finishedExecutors(fullId) = executor
+ executors -= fullId
+ coresUsed -= executor.cores
+ memoryUsed -= executor.memory
+ }
+
+ case KillExecutor(appId, execId) =>
+ val fullId = appId + "/" + execId
+ executors.get(fullId) match {
+ case Some(executor) =>
+ logInfo("Asked to kill executor " + fullId)
+ executor.kill()
+ case None =>
+ logInfo("Asked to kill unknown executor " + fullId)
+ }
+
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ masterDisconnected()
+
+ case RequestWorkerState => {
+ sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
+ finishedExecutors.values.toList, masterUrl, cores, memory,
+ coresUsed, memoryUsed, masterWebUiUrl)
+ }
+ }
+
+ def masterDisconnected() {
+ // TODO: It would be nice to try to reconnect to the master, but just shut down for now.
+ // (Note that if reconnecting we would also need to assign IDs differently.)
+ logError("Connection to master failed! Shutting down.")
+ executors.values.foreach(_.kill())
+ System.exit(1)
+ }
+
+ def generateWorkerId(): String = {
+ "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
+ }
+
+ override def postStop() {
+ executors.values.foreach(_.kill())
+ webUi.stop()
+ metricsSystem.stop()
+ }
+}
+
+private[spark] object Worker {
+ def main(argStrings: Array[String]) {
+ val args = new WorkerArguments(argStrings)
+ val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
+ args.memory, args.master, args.workDir)
+ actorSystem.awaitTermination()
+ }
+
+ def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
+ masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ // The LocalSparkCluster runs multiple local sparkWorkerX actor systems
+ val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
+ val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
+ masterUrl, workDir)), name = "Worker")
+ (actorSystem, boundPort)
+ }
+
+}
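
HEARTBEAT_MILLIS above is one quarter of the master's worker timeout, converted from seconds to milliseconds, so several heartbeats fit inside a single timeout window. The arithmetic traced with the default spark.worker.timeout of 60 seconds:

    object HeartbeatMath {
      def heartbeatMillis(workerTimeoutSeconds: Long): Long = workerTimeoutSeconds * 1000 / 4

      def main(args: Array[String]): Unit = {
        println(heartbeatMillis(60))   // 15000 ms with the default 60 s timeout
        println(heartbeatMillis(120))  // 30000 ms
      }
    }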
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
new file mode 100644
index 0000000000..6d91223413
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.util.IntParam
+import org.apache.spark.util.MemoryParam
+import org.apache.spark.Utils
+import java.lang.management.ManagementFactory
+
+/**
+ * Command-line parser for the worker.
+ */
+private[spark] class WorkerArguments(args: Array[String]) {
+ var host = Utils.localHostName()
+ var port = 0
+ var webUiPort = 8081
+ var cores = inferDefaultCores()
+ var memory = inferDefaultMemory()
+ var master: String = null
+ var workDir: String = null
+
+ // Check for settings in environment variables
+ if (System.getenv("SPARK_WORKER_PORT") != null) {
+ port = System.getenv("SPARK_WORKER_PORT").toInt
+ }
+ if (System.getenv("SPARK_WORKER_CORES") != null) {
+ cores = System.getenv("SPARK_WORKER_CORES").toInt
+ }
+ if (System.getenv("SPARK_WORKER_MEMORY") != null) {
+ memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY"))
+ }
+ if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
+ webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
+ }
+ if (System.getenv("SPARK_WORKER_DIR") != null) {
+ workDir = System.getenv("SPARK_WORKER_DIR")
+ }
+
+ parse(args.toList)
+
+ def parse(args: List[String]): Unit = args match {
+ case ("--ip" | "-i") :: value :: tail =>
+ Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--host" | "-h") :: value :: tail =>
+ Utils.checkHost(value, "Please use hostname " + value)
+ host = value
+ parse(tail)
+
+ case ("--port" | "-p") :: IntParam(value) :: tail =>
+ port = value
+ parse(tail)
+
+ case ("--cores" | "-c") :: IntParam(value) :: tail =>
+ cores = value
+ parse(tail)
+
+ case ("--memory" | "-m") :: MemoryParam(value) :: tail =>
+ memory = value
+ parse(tail)
+
+ case ("--work-dir" | "-d") :: value :: tail =>
+ workDir = value
+ parse(tail)
+
+ case "--webui-port" :: IntParam(value) :: tail =>
+ webUiPort = value
+ parse(tail)
+
+ case ("--help" | "-h") :: tail =>
+ printUsageAndExit(0)
+
+ case value :: tail =>
+ if (master != null) { // Two positional arguments were given
+ printUsageAndExit(1)
+ }
+ master = value
+ parse(tail)
+
+ case Nil =>
+ if (master == null) { // No positional argument was given
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1)
+ }
+
+ /**
+ * Print usage and exit JVM with the given exit code.
+ */
+ def printUsageAndExit(exitCode: Int) {
+ System.err.println(
+ "Usage: Worker [options] <master>\n" +
+ "\n" +
+ "Master must be a URL of the form spark://hostname:port\n" +
+ "\n" +
+ "Options:\n" +
+ " -c CORES, --cores CORES Number of cores to use\n" +
+ " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
+ " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
+ " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
+ " -h HOST, --host HOST Hostname to listen on\n" +
+ " -p PORT, --port PORT Port to listen on (default: random)\n" +
+ " --webui-port PORT Port for web UI (default: 8081)")
+ System.exit(exitCode)
+ }
+
+ def inferDefaultCores(): Int = {
+ Runtime.getRuntime.availableProcessors()
+ }
+
+ def inferDefaultMemory(): Int = {
+ val ibmVendor = System.getProperty("java.vendor").contains("IBM")
+ var totalMb = 0
+ try {
+ val bean = ManagementFactory.getOperatingSystemMXBean()
+ if (ibmVendor) {
+ val beanClass = Class.forName("com.ibm.lang.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemory")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ } else {
+ val beanClass = Class.forName("com.sun.management.OperatingSystemMXBean")
+ val method = beanClass.getDeclaredMethod("getTotalPhysicalMemorySize")
+ totalMb = (method.invoke(bean).asInstanceOf[Long] / 1024 / 1024).toInt
+ }
+ } catch {
+ case e: Exception => {
+ totalMb = 2*1024
+ System.out.println("Failed to get total physical memory. Using " + totalMb + " MB")
+ }
+ }
+ // Leave out 1 GB for the operating system, but don't return a negative memory size
+ math.max(totalMb - 1024, 512)
+ }
+}
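
inferDefaultMemory reserves 1 GB for the operating system but never reports less than 512 MB, and it falls back to an assumed 2 GB of physical memory if the MXBean lookup fails. The clamping traced on a few sizes:

    object DefaultMemoryMath {
      def defaultMemoryMb(totalMb: Int): Int = math.max(totalMb - 1024, 512)

      def main(args: Array[String]): Unit = {
        println(defaultMemoryMb(16384)) // 15360 MB on a 16 GB box
        println(defaultMemoryMb(1024))  // 512 MB floor on a 1 GB box
        println(defaultMemoryMb(2048))  // 1024 MB on the 2 GB fallback used when detection fails
      }
    }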
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
new file mode 100644
index 0000000000..6427c0178f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala
@@ -0,0 +1,34 @@
+package org.apache.spark.deploy.worker
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.spark.metrics.source.Source
+
+private[spark] class WorkerSource(val worker: Worker) extends Source {
+ val sourceName = "worker"
+ val metricRegistry = new MetricRegistry()
+
+ metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] {
+ override def getValue: Int = worker.executors.size
+ })
+
+ // Gauge for the cores currently used by this worker
+ metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] {
+ override def getValue: Int = worker.coresUsed
+ })
+
+ // Gauge for the memory currently used by this worker
+ metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] {
+ override def getValue: Int = worker.memoryUsed
+ })
+
+ // Gauge for the cores currently free on this worker
+ metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] {
+ override def getValue: Int = worker.coresFree
+ })
+
+ // Gauge for the memory currently free on this worker
+ metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] {
+ override def getValue: Int = worker.memoryFree
+ })
+}
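
Each gauge above is a zero-argument callback that the metrics library re-evaluates on every read. A standalone sketch of registering and reading such a gauge against the same com.codahale.metrics API (not wired into Spark's MetricsSystem, and the metric name here is illustrative):

    import com.codahale.metrics.{Gauge, MetricRegistry}

    object GaugeSketch {
      def main(args: Array[String]): Unit = {
        val registry = new MetricRegistry()
        var executorCount = 3  // stand-in for worker.executors.size

        registry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] {
          override def getValue: Int = executorCount
        })

        // The registry re-evaluates the gauge on every read.
        println(registry.getGauges.get("executors.number").getValue) // 3
        executorCount = 5
        println(registry.getGauges.get("executors.number").getValue) // 5
      }
    }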
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
new file mode 100644
index 0000000000..6192c2324b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import akka.dispatch.Await
+import akka.pattern.ask
+import akka.util.duration._
+
+import net.liftweb.json.JsonAST.JValue
+
+import org.apache.spark.Utils
+import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse}
+import org.apache.spark.deploy.worker.ExecutorRunner
+import org.apache.spark.ui.UIUtils
+
+
+private[spark] class IndexPage(parent: WorkerWebUI) {
+ val workerActor = parent.worker.self
+ val worker = parent.worker
+ val timeout = parent.timeout
+
+ def renderJson(request: HttpServletRequest): JValue = {
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
+ val workerState = Await.result(stateFuture, 30 seconds)
+ JsonProtocol.writeWorkerState(workerState)
+ }
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
+ val workerState = Await.result(stateFuture, 30 seconds)
+
+ val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs")
+ val runningExecutorTable =
+ UIUtils.listingTable(executorHeaders, executorRow, workerState.executors)
+ val finishedExecutorTable =
+ UIUtils.listingTable(executorHeaders, executorRow, workerState.finishedExecutors)
+
+ val content =
+ <div class="row-fluid"> <!-- Worker Details -->
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>ID:</strong> {workerState.workerId}</li>
+ <li><strong>
+ Master URL:</strong> {workerState.masterUrl}
+ </li>
+ <li><strong>Cores:</strong> {workerState.cores} ({workerState.coresUsed} Used)</li>
+ <li><strong>Memory:</strong> {Utils.megabytesToString(workerState.memory)}
+ ({Utils.megabytesToString(workerState.memoryUsed)} Used)</li>
+ </ul>
+ <p><a href={workerState.masterWebUiUrl}>Back to Master</a></p>
+ </div>
+ </div>
+
+ <div class="row-fluid"> <!-- Running Executors -->
+ <div class="span12">
+ <h4> Running Executors {workerState.executors.size} </h4>
+ {runningExecutorTable}
+ </div>
+ </div>
+
+ <div class="row-fluid"> <!-- Finished Executors -->
+ <div class="span12">
+ <h4> Finished Executors </h4>
+ {finishedExecutorTable}
+ </div>
+ </div>;
+
+ UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format(
+ workerState.host, workerState.port))
+ }
+
+ def executorRow(executor: ExecutorRunner): Seq[Node] = {
+ <tr>
+ <td>{executor.execId}</td>
+ <td>{executor.cores}</td>
+ <td sorttable_customkey={executor.memory.toString}>
+ {Utils.megabytesToString(executor.memory)}
+ </td>
+ <td>
+ <ul class="unstyled">
+ <li><strong>ID:</strong> {executor.appId}</li>
+ <li><strong>Name:</strong> {executor.appDesc.name}</li>
+ <li><strong>User:</strong> {executor.appDesc.user}</li>
+ </ul>
+ </td>
+ <td>
+ <a href={"logPage?appId=%s&executorId=%s&logType=stdout"
+ .format(executor.appId, executor.execId)}>stdout</a>
+ <a href={"logPage?appId=%s&executorId=%s&logType=stderr"
+ .format(executor.appId, executor.execId)}>stderr</a>
+ </td>
+ </tr>
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
new file mode 100644
index 0000000000..bb8165ac09
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker.ui
+
+import akka.util.{Duration, Timeout}
+
+import java.io.{FileInputStream, File}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.{Handler, Server}
+
+import org.apache.spark.deploy.worker.Worker
+import org.apache.spark.{Utils, Logging}
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.ui.UIUtils
+
+/**
+ * Web UI server for the standalone worker.
+ */
+private[spark]
+class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
+ extends Logging {
+ implicit val timeout = Timeout(
+ Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
+ val host = Utils.localHostName()
+ val port = requestedPort.getOrElse(
+ System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt)
+
+ var server: Option[Server] = None
+ var boundPort: Option[Int] = None
+
+ val indexPage = new IndexPage(this)
+
+ val metricsHandlers = worker.metricsSystem.getServletHandlers
+
+ val handlers = metricsHandlers ++ Array[(String, Handler)](
+ ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)),
+ ("/log", (request: HttpServletRequest) => log(request)),
+ ("/logPage", (request: HttpServletRequest) => logPage(request)),
+ ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)),
+ ("*", (request: HttpServletRequest) => indexPage.render(request))
+ )
+
+ def start() {
+ try {
+ val (srv, bPort) = JettyUtils.startJettyServer("0.0.0.0", port, handlers)
+ server = Some(srv)
+ boundPort = Some(bPort)
+ logInfo("Started Worker web UI at http://%s:%d".format(host, bPort))
+ } catch {
+ case e: Exception =>
+ logError("Failed to create Worker JettyUtils", e)
+ System.exit(1)
+ }
+ }
+
+ def log(request: HttpServletRequest): String = {
+ val defaultBytes = 100 * 1024
+ val appId = request.getParameter("appId")
+ val executorId = request.getParameter("executorId")
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType)
+
+ val (startByte, endByte) = getByteRange(path, offset, byteLength)
+ val file = new File(path)
+ val logLength = file.length
+
+ val pre = "==== Bytes %s-%s of %s of %s/%s/%s ====\n"
+ .format(startByte, endByte, logLength, appId, executorId, logType)
+ pre + Utils.offsetBytes(path, startByte, endByte)
+ }
+
+ def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = {
+ val defaultBytes = 100 * 1024
+ val appId = request.getParameter("appId")
+ val executorId = request.getParameter("executorId")
+ val logType = request.getParameter("logType")
+ val offset = Option(request.getParameter("offset")).map(_.toLong)
+ val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val path = "%s/%s/%s/%s".format(workDir.getPath, appId, executorId, logType)
+
+ val (startByte, endByte) = getByteRange(path, offset, byteLength)
+ val file = new File(path)
+ val logLength = file.length
+
+ val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
+
+ val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
+
+ val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
+
+ val backButton =
+ if (startByte > 0) {
+ <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s"
+ .format(appId, executorId, logType, math.max(startByte-byteLength, 0),
+ byteLength)}>
+ <button type="button" class="btn btn-default">
+ Previous {Utils.bytesToString(math.min(byteLength, startByte))}
+ </button>
+ </a>
+ }
+ else {
+ <button type="button" class="btn btn-default" disabled="disabled">
+ Previous 0 B
+ </button>
+ }
+
+ val nextButton =
+ if (endByte < logLength) {
+ <a href={"?appId=%s&executorId=%s&logType=%s&offset=%s&byteLength=%s".
+ format(appId, executorId, logType, endByte, byteLength)}>
+ <button type="button" class="btn btn-default">
+ Next {Utils.bytesToString(math.min(byteLength, logLength-endByte))}
+ </button>
+ </a>
+ }
+ else {
+ <button type="button" class="btn btn-default" disabled="disabled">
+ Next 0 B
+ </button>
+ }
+
+ val content =
+ <html>
+ <body>
+ {linkToMaster}
+ <div>
+ <div style="float:left;width:40%">{backButton}</div>
+ <div style="float:left;">{range}</div>
+ <div style="float:right;">{nextButton}</div>
+ </div>
+ <br />
+ <div style="height:500px;overflow:auto;padding:5px;">
+ <pre>{logText}</pre>
+ </div>
+ </body>
+ </html>
+ UIUtils.basicSparkPage(content, logType + " log page for " + appId)
+ }
+
+ /** Determine the byte range for a log or log page. */
+ def getByteRange(path: String, offset: Option[Long], byteLength: Int)
+ : (Long, Long) = {
+ val defaultBytes = 100 * 1024
+ val maxBytes = 1024 * 1024
+
+ val file = new File(path)
+ val logLength = file.length()
+ val getOffset = offset.getOrElse(logLength-defaultBytes)
+
+ val startByte =
+ if (getOffset < 0) 0L
+ else if (getOffset > logLength) logLength
+ else getOffset
+
+ val logPageLength = math.min(byteLength, maxBytes)
+
+ val endByte = math.min(startByte+logPageLength, logLength)
+
+ (startByte, endByte)
+ }
+
+ def stop() {
+ server.foreach(_.stop())
+ }
+}
+
+private[spark] object WorkerWebUI {
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+ val DEFAULT_PORT = "8081"
+}
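
getByteRange clamps the requested log window: a missing offset defaults to the last 100 KB, negative offsets snap to 0, the window is capped at 1 MB, and the end never runs past the file length. A standalone copy of the arithmetic with the file length passed in instead of read from disk, plus a few traced cases:

    object ByteRangeMath {
      val DefaultBytes = 100 * 1024
      val MaxBytes = 1024 * 1024

      def byteRange(logLength: Long, offset: Option[Long], byteLength: Int): (Long, Long) = {
        val requested = offset.getOrElse(logLength - DefaultBytes)
        val start = math.min(math.max(requested, 0L), logLength)       // clamp into [0, logLength]
        val end = math.min(start + math.min(byteLength, MaxBytes), logLength)
        (start, end)
      }

      def main(args: Array[String]): Unit = {
        println(byteRange(500000L, None, DefaultBytes))   // (397600,500000): tail of the log
        println(byteRange(500000L, Some(-10L), 1000))     // (0,1000): negative offsets clamp to 0
        println(byteRange(500000L, Some(499500L), 1000))  // (499500,500000): end clamps to file length
      }
    }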
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
new file mode 100644
index 0000000000..5446a3fca9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.io.{File}
+import java.lang.management.ManagementFactory
+import java.nio.ByteBuffer
+import java.util.concurrent._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.scheduler._
+import org.apache.spark._
+
+
+/**
+ * The Mesos executor for Spark.
+ */
+private[spark] class Executor(
+ executorId: String,
+ slaveHostname: String,
+ properties: Seq[(String, String)])
+ extends Logging
+{
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+
+ private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+
+ initLogging()
+
+ // No IP address or host:port - just a hostname
+ Utils.checkHost(slaveHostname, "Expected executor slave to be a hostname")
+ // Must not have a port specified.
+ assert (0 == Utils.parseHostPort(slaveHostname)._2)
+
+ // Make sure the local hostname we report matches the cluster scheduler's name for this host
+ Utils.setCustomHostname(slaveHostname)
+
+ // Set spark.* system properties from executor arg
+ for ((key, value) <- properties) {
+ System.setProperty(key, value)
+ }
+
+ // If we are in YARN mode, nodes can have different disk layouts, so we must set the local dirs
+ // to what YARN on this system said was available. This will be used later when SparkEnv is
+ // created.
+ if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) {
+ System.setProperty("spark.local.dir", getYarnLocalDirs())
+ }
+
+ // Create our ClassLoader and set it on this thread
+ private val urlClassLoader = createClassLoader()
+ private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
+ }
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
+ }
+ }
+ }
+ )
+
+ val executorSource = new ExecutorSource(this)
+
+ // Initialize Spark environment (using system properties read above)
+ val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
+ SparkEnv.set(env)
+ env.metricsSystem.registerSource(executorSource)
+
+ private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+
+ // Start worker thread pool
+ val threadPool = new ThreadPoolExecutor(
+ 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+
+ def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
+ threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+ }
+
+ /** Get the Yarn approved local directories. */
+ private def getYarnLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x use different environment variable names for the
+ // local dirs, so let's check both. We assume one of the two is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .getOrElse(Option(System.getenv("LOCAL_DIRS"))
+ .getOrElse(""))
+
+ if (localDirs.isEmpty()) {
+ throw new Exception("Yarn Local dirs can't be empty")
+ }
+ return localDirs
+ }
+
+ class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
+ extends Runnable {
+
+ override def run() {
+ val startTime = System.currentTimeMillis()
+ SparkEnv.set(env)
+ Thread.currentThread.setContextClassLoader(replClassLoader)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ logInfo("Running task ID " + taskId)
+ context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
+ var attemptedTask: Option[Task[Any]] = None
+ var taskStart: Long = 0
+ def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
+ val startGCTime = getTotalGCTime
+
+ try {
+ SparkEnv.set(env)
+ Accumulators.clear()
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+ updateDependencies(taskFiles, taskJars)
+ val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ attemptedTask = Some(task)
+ logInfo("Its epoch is " + task.epoch)
+ env.mapOutputTracker.updateEpoch(task.epoch)
+ taskStart = System.currentTimeMillis()
+ val value = task.run(taskId.toInt)
+ val taskFinish = System.currentTimeMillis()
+ for (m <- task.metrics) {
+ m.hostname = Utils.localHostName
+ m.executorDeserializeTime = (taskStart - startTime).toInt
+ m.executorRunTime = (taskFinish - taskStart).toInt
+ m.jvmGCTime = getTotalGCTime - startGCTime
+ }
+ // TODO: I'd also like to track the time it takes to serialize the task results, but that is a
+ // huge headache because we need to serialize the task metrics first. If TaskMetrics had a
+ // custom serialized format, we could just change the relevant bytes in the byte buffer.
+ val accumUpdates = Accumulators.values
+ val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+ val serializedResult = ser.serialize(result)
+ logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
+ if (serializedResult.limit >= (akkaFrameSize - 1024)) {
+ context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
+ return
+ }
+ context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
+ logInfo("Finished task ID " + taskId)
+ } catch {
+ case ffe: FetchFailedException => {
+ val reason = ffe.toTaskEndReason
+ context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ }
+
+ case t: Throwable => {
+ val serviceTime = (System.currentTimeMillis() - taskStart).toInt
+ val metrics = attemptedTask.flatMap(t => t.metrics)
+ for (m <- metrics) {
+ m.executorRunTime = serviceTime
+ m.jvmGCTime = getTotalGCTime - startGCTime
+ }
+ val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
+ context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+
+ // TODO: Should we exit the whole executor here? On the one hand, the failed task may
+ // have left some weird state around depending on when the exception was thrown, but on
+ // the other hand, maybe we could detect that when future tasks fail and exit then.
+ logError("Exception in task ID " + taskId, t)
+ //System.exit(1)
+ }
+ }
+ }
+ }
+
+ /**
+ * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
+ * created by the interpreter to the search path
+ */
+ private def createClassLoader(): ExecutorURLClassLoader = {
+ var loader = this.getClass.getClassLoader
+
+ // For each of the jars in the jarSet, add them to the class loader.
+ // We assume each of the files has already been fetched.
+ val urls = currentJars.keySet.map { uri =>
+ new File(uri.split("/").last).toURI.toURL
+ }.toArray
+ new ExecutorURLClassLoader(urls, loader)
+ }
+
+ /**
+ * If the REPL is in use, add another ClassLoader that will read
+ * new classes defined by the REPL as the user types code
+ */
+ private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
+ val classUri = System.getProperty("spark.repl.class.uri")
+ if (classUri != null) {
+ logInfo("Using REPL class URI: " + classUri)
+ try {
+ val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
+ .asInstanceOf[Class[_ <: ClassLoader]]
+ val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
+ return constructor.newInstance(classUri, parent)
+ } catch {
+ case _: ClassNotFoundException =>
+ logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
+ System.exit(1)
+ null
+ }
+ } else {
+ return parent
+ }
+ }
+
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
+ }
+ }
+ }
+}
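
updateDependencies refetches a file or JAR only when the timestamp shipped with the task is newer than the one cached on the executor, so unchanged dependencies are not downloaded again. The comparison in isolation, with a hypothetical file name:

    import scala.collection.mutable.HashMap

    object DependencySketch {
      // name -> timestamp of the version already fetched on this executor
      val currentFiles = new HashMap[String, Long]()

      def needsFetch(name: String, timestamp: Long): Boolean =
        currentFiles.getOrElse(name, -1L) < timestamp

      def main(args: Array[String]): Unit = {
        println(needsFetch("lib.jar", 100L)) // true: never seen before
        currentFiles("lib.jar") = 100L
        println(needsFetch("lib.jar", 100L)) // false: same version already cached
        println(needsFetch("lib.jar", 200L)) // true: the driver shipped a newer copy
      }
    }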
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
new file mode 100644
index 0000000000..ad7dd34c76
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.nio.ByteBuffer
+import org.apache.spark.TaskState.TaskState
+
+/**
+ * A pluggable interface used by the Executor to send updates to the cluster scheduler.
+ */
+private[spark] trait ExecutorBackend {
+ def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
+}
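A minimal illustrative implementation of the trait, assuming it sits in the same org.apache.spark.executor package (the trait is private[spark]); the LoggingExecutorBackend name is hypothetical, and this backend only prints status updates instead of forwarding them to a scheduler.

  package org.apache.spark.executor

  import java.nio.ByteBuffer
  import org.apache.spark.TaskState.TaskState

  // Hypothetical backend that records status updates locally.
  class LoggingExecutorBackend extends ExecutorBackend {
    override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
      println("task " + taskId + " -> " + state + " (" + data.remaining + " bytes of result data)")
    }
  }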
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
new file mode 100644
index 0000000000..e5c9bbbe28
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+/**
+ * These are exit codes that executors should use to provide the master with information about
+ * executor failures, assuming that the cluster management framework can capture the exit codes (but
+ * perhaps not log files). The exit code constants here are chosen to be unlikely to conflict
+ * with "natural" exit statuses that may be caused by the JVM or user code. In particular,
+ * exit codes 128+ arise on some Unix-likes as a result of signals, and it appears that the
+ * OpenJDK JVM may use exit code 1 in some of its own "last chance" code.
+ */
+private[spark]
+object ExecutorExitCode {
+ /** The default uncaught exception handler was reached. */
+ val UNCAUGHT_EXCEPTION = 50
+
+ /** The default uncaught exception handler was called and an exception was encountered while
+ logging the exception. */
+ val UNCAUGHT_EXCEPTION_TWICE = 51
+
+ /** The default uncaught exception handler was reached, and the uncaught exception was an
+ OutOfMemoryError. */
+ val OOM = 52
+
+ /** DiskStore failed to create a local temporary directory after many attempts. */
+ val DISK_STORE_FAILED_TO_CREATE_DIR = 53
+
+ def explainExitCode(exitCode: Int): String = {
+ exitCode match {
+ case UNCAUGHT_EXCEPTION => "Uncaught exception"
+ case UNCAUGHT_EXCEPTION_TWICE => "Uncaught exception, and logging the exception failed"
+ case OOM => "OutOfMemoryError"
+ case DISK_STORE_FAILED_TO_CREATE_DIR =>
+ "Failed to create local directory (bad spark.local.dir?)"
+ case _ =>
+ "Unknown executor exit code (" + exitCode + ")" + (
+ if (exitCode > 128)
+ " (died from signal " + (exitCode - 128) + "?)"
+ else
+ ""
+ )
+ }
+ }
+}
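A short illustration of explainExitCode, for example when a cluster manager reports that an executor died; the chosen exit code is just an example.

  // Turning an executor's exit status into a human-readable log message.
  val exitCode = ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR
  println("Executor lost, exit code " + exitCode + ": " + ExecutorExitCode.explainExitCode(exitCode))
  // prints: Executor lost, exit code 53: Failed to create local directory (bad spark.local.dir?)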
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
new file mode 100644
index 0000000000..17653cd560
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -0,0 +1,55 @@
+package org.apache.spark.executor
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.hdfs.DistributedFileSystem
+import org.apache.hadoop.fs.LocalFileSystem
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.metrics.source.Source
+
+class ExecutorSource(val executor: Executor) extends Source {
+ private def fileStats(scheme: String) : Option[FileSystem.Statistics] =
+ FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption
+
+ private def registerFileSystemStat[T](
+ scheme: String, name: String, f: FileSystem.Statistics => T, defaultValue: T) = {
+ metricRegistry.register(MetricRegistry.name("filesystem", scheme, name), new Gauge[T] {
+ override def getValue: T = fileStats(scheme).map(f).getOrElse(defaultValue)
+ })
+ }
+
+ val metricRegistry = new MetricRegistry()
+ val sourceName = "executor"
+
+ // Gauge for the executor thread pool's count of actively executing tasks
+ metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] {
+ override def getValue: Int = executor.threadPool.getActiveCount()
+ })
+
+ // Gauge for executor thread pool's approximate total number of tasks that have been completed
+ metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] {
+ override def getValue: Long = executor.threadPool.getCompletedTaskCount()
+ })
+
+ // Gauge for executor thread pool's current number of threads
+ metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] {
+ override def getValue: Int = executor.threadPool.getPoolSize()
+ })
+
+ // Gauge for the executor thread pool's maximum allowed number of threads
+ metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", "size"), new Gauge[Int] {
+ override def getValue: Int = executor.threadPool.getMaximumPoolSize()
+ })
+
+ // Gauge for file system stats of this executor
+ for (scheme <- Array("hdfs", "file")) {
+ registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L)
+ registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L)
+ registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0)
+ registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0)
+ registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
new file mode 100644
index 0000000000..f9bfe8ed2f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.net.{URLClassLoader, URL}
+
+/**
+ * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
+ */
+private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
+ extends URLClassLoader(urls, parent) {
+
+ override def addURL(url: URL) {
+ super.addURL(url)
+ }
+}
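A brief usage sketch, assuming code living under the org.apache.spark package (the class is private[spark]); the jar file names are hypothetical.

  import java.io.File

  // Start with the jars fetched so far.
  val initialJars = Array(new File("dep-a.jar").toURI.toURL)
  val loader = new ExecutorURLClassLoader(initialJars, getClass.getClassLoader)

  // Later, a newly fetched jar can be appended without rebuilding the loader,
  // which is exactly why addURL is made public above.
  loader.addURL(new File("dep-b.jar").toURI.toURL)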
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
new file mode 100644
index 0000000000..410a94df6b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.nio.ByteBuffer
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
+import org.apache.spark.TaskState.TaskState
+import com.google.protobuf.ByteString
+import org.apache.spark.{Utils, Logging}
+import org.apache.spark.TaskState
+
+private[spark] class MesosExecutorBackend
+ extends MesosExecutor
+ with ExecutorBackend
+ with Logging {
+
+ var executor: Executor = null
+ var driver: ExecutorDriver = null
+
+ override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
+ val mesosTaskId = TaskID.newBuilder().setValue(taskId.toString).build()
+ driver.sendStatusUpdate(MesosTaskStatus.newBuilder()
+ .setTaskId(mesosTaskId)
+ .setState(TaskState.toMesos(state))
+ .setData(ByteString.copyFrom(data))
+ .build())
+ }
+
+ override def registered(
+ driver: ExecutorDriver,
+ executorInfo: ExecutorInfo,
+ frameworkInfo: FrameworkInfo,
+ slaveInfo: SlaveInfo) {
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+ this.driver = driver
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ executor = new Executor(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties)
+ }
+
+ override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
+ val taskId = taskInfo.getTaskId.getValue.toLong
+ if (executor == null) {
+ logError("Received launchTask but executor was null")
+ } else {
+ executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
+ }
+ }
+
+ override def error(d: ExecutorDriver, message: String) {
+ logError("Error from Mesos: " + message)
+ }
+
+ override def killTask(d: ExecutorDriver, t: TaskID) {
+ logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ }
+
+ override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
+
+ override def disconnected(d: ExecutorDriver) {}
+
+ override def frameworkMessage(d: ExecutorDriver, data: Array[Byte]) {}
+
+ override def shutdown(d: ExecutorDriver) {}
+}
+
+/**
+ * Entry point for Mesos executor.
+ */
+private[spark] object MesosExecutorBackend {
+ def main(args: Array[String]) {
+ MesosNativeLibrary.load()
+ // Create a new Executor and start it running
+ val runner = new MesosExecutorBackend()
+ new MesosExecutorDriver(runner).run()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
new file mode 100644
index 0000000000..65801f75b7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.nio.ByteBuffer
+
+import akka.actor.{ActorRef, Actor, Props, Terminated}
+import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
+
+import org.apache.spark.{Logging, Utils, SparkEnv}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.util.AkkaUtils
+
+
+private[spark] class StandaloneExecutorBackend(
+ driverUrl: String,
+ executorId: String,
+ hostPort: String,
+ cores: Int)
+ extends Actor
+ with ExecutorBackend
+ with Logging {
+
+ Utils.checkHostPort(hostPort, "Expected hostport")
+
+ var executor: Executor = null
+ var driver: ActorRef = null
+
+ override def preStart() {
+ logInfo("Connecting to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! RegisterExecutor(executorId, hostPort, cores)
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ }
+
+ override def receive = {
+ case RegisteredExecutor(sparkProperties) =>
+ logInfo("Successfully registered with driver")
+ // Make this host instead of hostPort ?
+ executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
+
+ case RegisterExecutorFailed(message) =>
+ logError("Slave registration failed: " + message)
+ System.exit(1)
+
+ case LaunchTask(taskDesc) =>
+ logInfo("Got assigned task " + taskDesc.taskId)
+ if (executor == null) {
+ logError("Received launchTask but executor was null")
+ System.exit(1)
+ } else {
+ executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
+ }
+
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ logError("Driver terminated or disconnected! Shutting down.")
+ System.exit(1)
+ }
+
+ override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
+ driver ! StatusUpdate(executorId, taskId, state, data)
+ }
+}
+
+private[spark] object StandaloneExecutorBackend {
+ def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+ // Debug code
+ Utils.checkHost(hostname)
+
+ // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
+ // before getting started with all our system properties, etc
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ // Record this executor's host:port in spark.hostPort so other components can find it
+ val sparkHostPort = hostname + ":" + boundPort
+ System.setProperty("spark.hostPort", sparkHostPort)
+ val actor = actorSystem.actorOf(
+ Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+ name = "Executor")
+ actorSystem.awaitTermination()
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 4) {
+ // The reason we allow the trailing <appid> argument is to make it easy to kill rogue executors
+ System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
+ System.exit(1)
+ }
+ run(args(0), args(1), args(2), args(3).toInt)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
new file mode 100644
index 0000000000..f311141148
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+class TaskMetrics extends Serializable {
+ /**
+ * Name of the host the task runs on
+ */
+ var hostname: String = _
+
+ /**
+ * Time taken on the executor to deserialize this task
+ */
+ var executorDeserializeTime: Int = _
+
+ /**
+ * Time the executor spends actually running the task (including fetching shuffle data)
+ */
+ var executorRunTime: Int = _
+
+ /**
+ * The number of bytes this task transmitted back to the driver as the TaskResult
+ */
+ var resultSize: Long = _
+
+ /**
+ * Amount of time the JVM spent in garbage collection while executing this task
+ */
+ var jvmGCTime: Long = _
+
+ /**
+ * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+ */
+ var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+ /**
+ * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+ */
+ var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+}
+
+object TaskMetrics {
+ private[spark] def empty(): TaskMetrics = new TaskMetrics
+}
+
+
+class ShuffleReadMetrics extends Serializable {
+ /**
+ * Time when the shuffle finished
+ */
+ var shuffleFinishTime: Long = _
+
+ /**
+ * Total number of blocks fetched in a shuffle (remote or local)
+ */
+ var totalBlocksFetched: Int = _
+
+ /**
+ * Number of remote blocks fetched in a shuffle
+ */
+ var remoteBlocksFetched: Int = _
+
+ /**
+ * Number of local blocks fetched in a shuffle
+ */
+ var localBlocksFetched: Int = _
+
+ /**
+ * Total time that is spent blocked waiting for shuffle to fetch data
+ */
+ var fetchWaitTime: Long = _
+
+ /**
+ * The total amount of time for all the shuffle fetches. This adds up time from overlapping
+ * shuffles, so it can be longer than the task time
+ */
+ var remoteFetchTime: Long = _
+
+ /**
+ * Total number of remote bytes read from a shuffle
+ */
+ var remoteBytesRead: Long = _
+}
+
+class ShuffleWriteMetrics extends Serializable {
+ /**
+ * Number of bytes written for a shuffle
+ */
+ var shuffleBytesWritten: Long = _
+}
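A hypothetical sketch of how an executor might fill in these metrics around a task run; the host name and the timing variables are illustrative only.

  val metrics = TaskMetrics.empty()
  metrics.hostname = "worker-1"

  val deserStart = System.currentTimeMillis
  // ... deserialize the task here ...
  val runStart = System.currentTimeMillis
  metrics.executorDeserializeTime = (runStart - deserStart).toInt

  // ... run the task here ...
  metrics.executorRunTime = (System.currentTimeMillis - runStart).toInt
  metrics.resultSize = 0L // set to the serialized result's size once it is known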
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
new file mode 100644
index 0000000000..90a0420caf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.io
+
+import java.io.{InputStream, OutputStream}
+
+import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
+
+import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream}
+
+
+/**
+ * CompressionCodec allows the customization of choosing different compression implementations
+ * to be used in block storage.
+ */
+trait CompressionCodec {
+
+ def compressedOutputStream(s: OutputStream): OutputStream
+
+ def compressedInputStream(s: InputStream): InputStream
+}
+
+
+private[spark] object CompressionCodec {
+
+ def createCodec(): CompressionCodec = {
+ // Set the default codec to Snappy since the LZF implementation initializes a pretty large
+ // buffer for every stream, which results in a lot of memory overhead when the number of
+ // shuffle reduce buckets is large.
+ createCodec(classOf[SnappyCompressionCodec].getName)
+ }
+
+ def createCodec(codecName: String): CompressionCodec = {
+ Class.forName(
+ System.getProperty("spark.io.compression.codec", codecName),
+ true,
+ Thread.currentThread.getContextClassLoader).newInstance().asInstanceOf[CompressionCodec]
+ }
+}
+
+
+/**
+ * LZF implementation of [[org.apache.spark.io.CompressionCodec]].
+ */
+class LZFCompressionCodec extends CompressionCodec {
+
+ override def compressedOutputStream(s: OutputStream): OutputStream = {
+ new LZFOutputStream(s).setFinishBlockOnFlush(true)
+ }
+
+ override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s)
+}
+
+
+/**
+ * Snappy implementation of [[org.apache.spark.io.CompressionCodec]].
+ * Block size can be configured by spark.io.compression.snappy.block.size.
+ */
+class SnappyCompressionCodec extends CompressionCodec {
+
+ override def compressedOutputStream(s: OutputStream): OutputStream = {
+ val blockSize = System.getProperty("spark.io.compression.snappy.block.size", "32768").toInt
+ new SnappyOutputStream(s, blockSize)
+ }
+
+ override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s)
+}
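An illustrative round trip through whichever codec createCodec selects (Snappy unless spark.io.compression.codec says otherwise); the payload string is arbitrary.

  import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

  val codec = CompressionCodec.createCodec()

  // Compress a small payload into an in-memory buffer ...
  val bytesOut = new ByteArrayOutputStream()
  val out = codec.compressedOutputStream(bytesOut)
  out.write("some block data".getBytes("UTF-8"))
  out.close()

  // ... and read it back through the matching input stream.
  val in = codec.compressedInputStream(new ByteArrayInputStream(bytesOut.toByteArray))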
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
new file mode 100644
index 0000000000..0f9c4e00b1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics
+
+import java.util.Properties
+import java.io.{File, FileInputStream, InputStream, IOException}
+
+import scala.collection.mutable
+import scala.util.matching.Regex
+
+import org.apache.spark.Logging
+
+private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
+ initLogging()
+
+ val DEFAULT_PREFIX = "*"
+ val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
+ val METRICS_CONF = "metrics.properties"
+
+ val properties = new Properties()
+ var propertyCategories: mutable.HashMap[String, Properties] = null
+
+ private def setDefaultProperties(prop: Properties) {
+ prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
+ prop.setProperty("*.sink.servlet.uri", "/metrics/json")
+ prop.setProperty("*.sink.servlet.sample", "false")
+ prop.setProperty("master.sink.servlet.uri", "/metrics/master/json")
+ prop.setProperty("applications.sink.servlet.uri", "/metrics/applications/json")
+ }
+
+ def initialize() {
+ // Add default properties in case there's no properties file
+ setDefaultProperties(properties)
+
+ // If spark.metrics.conf is not set, try to get file in class path
+ var is: InputStream = null
+ try {
+ is = configFile match {
+ case Some(f) => new FileInputStream(f)
+ case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
+ }
+
+ if (is != null) {
+ properties.load(is)
+ }
+ } catch {
+ case e: Exception => logError("Error loading configuration file", e)
+ } finally {
+ if (is != null) is.close()
+ }
+
+ propertyCategories = subProperties(properties, INSTANCE_REGEX)
+ if (propertyCategories.contains(DEFAULT_PREFIX)) {
+ import scala.collection.JavaConversions._
+
+ val defaultProperty = propertyCategories(DEFAULT_PREFIX)
+ for { (inst, prop) <- propertyCategories
+ if (inst != DEFAULT_PREFIX)
+ (k, v) <- defaultProperty
+ if (prop.getProperty(k) == null) } {
+ prop.setProperty(k, v)
+ }
+ }
+ }
+
+ def subProperties(prop: Properties, regex: Regex): mutable.HashMap[String, Properties] = {
+ val subProperties = new mutable.HashMap[String, Properties]
+ import scala.collection.JavaConversions._
+ prop.foreach { kv =>
+ if (regex.findPrefixOf(kv._1) != None) {
+ val regex(prefix, suffix) = kv._1
+ subProperties.getOrElseUpdate(prefix, new Properties).setProperty(suffix, kv._2)
+ }
+ }
+ subProperties
+ }
+
+ def getInstance(inst: String): Properties = {
+ propertyCategories.get(inst) match {
+ case Some(s) => s
+ case None => propertyCategories.getOrElse(DEFAULT_PREFIX, new Properties)
+ }
+ }
+}
+
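An illustrative sketch, assuming no spark.metrics.conf file and no metrics.properties on the classpath, showing how getInstance falls back to the "*" defaults set above.

  val conf = new MetricsConfig(None) // no explicit spark.metrics.conf file
  conf.initialize()

  // "executor" has no section of its own here, so the "*" defaults apply.
  val executorProps = conf.getInstance("executor")
  println(executorProps.getProperty("sink.servlet.class"))
  // expected: org.apache.spark.metrics.sink.MetricsServlet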
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
new file mode 100644
index 0000000000..bec0c83be8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics
+
+import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
+import org.apache.spark.metrics.source.Source
+
+/**
+ * Spark's metrics system, created per "instance", combines sources and sinks and
+ * periodically polls metrics data from the sources to the sink destinations.
+ *
+ * "instance" specifies "who" (which role) uses the metrics system. In Spark, several roles
+ * such as master, worker, executor, and client driver create a metrics system for
+ * monitoring, so an instance represents one of these roles. The instances currently
+ * implemented are: master, worker, executor, driver, and applications.
+ *
+ * "source" specifies "where" (from which source) metrics data is collected. There are
+ * two kinds of source:
+ * 1. Spark internal sources, like MasterSource, WorkerSource, etc, which collect a
+ * Spark component's internal state. These sources are tied to an instance and are
+ * added after the specific metrics system is created.
+ * 2. Common sources, like JvmSource, which collect low-level state. These are set up
+ * through configuration and loaded via reflection.
+ *
+ * "sink" specifies "where" (to which destination) metrics data is output. Several sinks
+ * can coexist, and metrics are flushed to all of them.
+ *
+ * The metrics configuration format is:
+ * [instance].[sink|source].[name].[options] = xxxx
+ *
+ * [instance] can be "master", "worker", "executor", "driver" or "applications", meaning
+ * only the specified instance has this property.
+ * The wildcard "*" can be used in place of an instance name, meaning all instances have
+ * this property.
+ *
+ * [sink|source] indicates whether the property belongs to a sink or a source; it can only
+ * be "sink" or "source".
+ *
+ * [name] specifies the name of the sink or source; it is user defined.
+ *
+ * [options] is the specific property of this source or sink.
+ */
+private[spark] class MetricsSystem private (val instance: String) extends Logging {
+ initLogging()
+
+ val confFile = System.getProperty("spark.metrics.conf")
+ val metricsConfig = new MetricsConfig(Option(confFile))
+
+ val sinks = new mutable.ArrayBuffer[Sink]
+ val sources = new mutable.ArrayBuffer[Source]
+ val registry = new MetricRegistry()
+
+ // Treat MetricsServlet as a special sink, since it has to be exposed so its handlers can be added to the web UI
+ private var metricsServlet: Option[MetricsServlet] = None
+
+ /** Get any UI handlers used by this metrics system. */
+ def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
+
+ metricsConfig.initialize()
+ registerSources()
+ registerSinks()
+
+ def start() {
+ sinks.foreach(_.start)
+ }
+
+ def stop() {
+ sinks.foreach(_.stop)
+ }
+
+ def registerSource(source: Source) {
+ sources += source
+ try {
+ registry.register(source.sourceName, source.metricRegistry)
+ } catch {
+ case e: IllegalArgumentException => logInfo("Metrics already registered", e)
+ }
+ }
+
+ def removeSource(source: Source) {
+ sources -= source
+ registry.removeMatching(new MetricFilter {
+ def matches(name: String, metric: Metric): Boolean = name.startsWith(source.sourceName)
+ })
+ }
+
+ def registerSources() {
+ val instConfig = metricsConfig.getInstance(instance)
+ val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
+
+ // Register all the sources related to instance
+ sourceConfigs.foreach { kv =>
+ val classPath = kv._2.getProperty("class")
+ try {
+ val source = Class.forName(classPath).newInstance()
+ registerSource(source.asInstanceOf[Source])
+ } catch {
+ case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
+ }
+ }
+ }
+
+ def registerSinks() {
+ val instConfig = metricsConfig.getInstance(instance)
+ val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
+
+ sinkConfigs.foreach { kv =>
+ val classPath = kv._2.getProperty("class")
+ try {
+ val sink = Class.forName(classPath)
+ .getConstructor(classOf[Properties], classOf[MetricRegistry])
+ .newInstance(kv._2, registry)
+ if (kv._1 == "servlet") {
+ metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
+ } else {
+ sinks += sink.asInstanceOf[Sink]
+ }
+ } catch {
+ case e: Exception => logError("Sink class " + classPath + " cannot be instantiated", e)
+ }
+ }
+ }
+}
+
+private[spark] object MetricsSystem {
+ val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r
+ val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r
+
+ val MINIMAL_POLL_UNIT = TimeUnit.SECONDS
+ val MINIMAL_POLL_PERIOD = 1
+
+ def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) {
+ val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit)
+ if (period < MINIMAL_POLL_PERIOD) {
+ throw new IllegalArgumentException("Polling period " + pollPeriod + " " + pollUnit +
+ " below than minimal polling period ")
+ }
+ }
+
+ def createMetricsSystem(instance: String): MetricsSystem = new MetricsSystem(instance)
+}
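An illustrative wiring sketch for an "executor" instance; the configuration keys in the comment follow the [instance].[sink|source].[name].[options] format described in the class comment, and the values are examples rather than shipped defaults.

  // Example configuration (values are illustrative, not defaults):
  //   executor.sink.console.class  = org.apache.spark.metrics.sink.ConsoleSink
  //   executor.sink.console.period = 10
  //   executor.sink.console.unit   = seconds
  //   *.source.jvm.class           = org.apache.spark.metrics.source.JvmSource
  val metricsSystem = MetricsSystem.createMetricsSystem("executor")
  metricsSystem.registerSource(new org.apache.spark.metrics.source.JvmSource)
  metricsSystem.start()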
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
new file mode 100644
index 0000000000..bce257d6e6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import com.codahale.metrics.{ConsoleReporter, MetricRegistry}
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.metrics.MetricsSystem
+
+class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val CONSOLE_DEFAULT_PERIOD = 10
+ val CONSOLE_DEFAULT_UNIT = "SECONDS"
+
+ val CONSOLE_KEY_PERIOD = "period"
+ val CONSOLE_KEY_UNIT = "unit"
+
+ val pollPeriod = Option(property.getProperty(CONSOLE_KEY_PERIOD)) match {
+ case Some(s) => s.toInt
+ case None => CONSOLE_DEFAULT_PERIOD
+ }
+
+ val pollUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT)
+ }
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val reporter: ConsoleReporter = ConsoleReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .build()
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
new file mode 100644
index 0000000000..3d1a06a395
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import com.codahale.metrics.{CsvReporter, MetricRegistry}
+
+import java.io.File
+import java.util.{Locale, Properties}
+import java.util.concurrent.TimeUnit
+
+import org.apache.spark.metrics.MetricsSystem
+
+class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val CSV_KEY_PERIOD = "period"
+ val CSV_KEY_UNIT = "unit"
+ val CSV_KEY_DIR = "directory"
+
+ val CSV_DEFAULT_PERIOD = 10
+ val CSV_DEFAULT_UNIT = "SECONDS"
+ val CSV_DEFAULT_DIR = "/tmp/"
+
+ val pollPeriod = Option(property.getProperty(CSV_KEY_PERIOD)) match {
+ case Some(s) => s.toInt
+ case None => CSV_DEFAULT_PERIOD
+ }
+
+ val pollUnit = Option(property.getProperty(CSV_KEY_UNIT)) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT)
+ }
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match {
+ case Some(s) => s
+ case None => CSV_DEFAULT_DIR
+ }
+
+ val reporter: CsvReporter = CsvReporter.forRegistry(registry)
+ .formatFor(Locale.US)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .build(new File(pollDir))
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
new file mode 100644
index 0000000000..621d086d41
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import com.codahale.metrics.{JmxReporter, MetricRegistry}
+
+import java.util.Properties
+
+class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val reporter: JmxReporter = JmxReporter.forRegistry(registry).build()
+
+ override def start() {
+ reporter.start()
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
new file mode 100644
index 0000000000..4e90dd4323
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.json.MetricsModule
+
+import com.fasterxml.jackson.databind.ObjectMapper
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Handler
+
+import org.apache.spark.ui.JettyUtils
+
+class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val SERVLET_KEY_URI = "uri"
+ val SERVLET_KEY_SAMPLE = "sample"
+
+ val servletURI = property.getProperty(SERVLET_KEY_URI)
+
+ val servletShowSample = property.getProperty(SERVLET_KEY_SAMPLE).toBoolean
+
+ val mapper = new ObjectMapper().registerModule(
+ new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample))
+
+ def getHandlers = Array[(String, Handler)](
+ (servletURI, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json"))
+ )
+
+ def getMetricsSnapshot(request: HttpServletRequest): String = {
+ mapper.writeValueAsString(registry)
+ }
+
+ override def start() { }
+
+ override def stop() { }
+}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala
new file mode 100644
index 0000000000..3a739aa563
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+trait Sink {
+ def start: Unit
+ def stop: Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala
new file mode 100644
index 0000000000..75cb2b8973
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.source
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet}
+
+class JvmSource extends Source {
+ val sourceName = "jvm"
+ val metricRegistry = new MetricRegistry()
+
+ val gcMetricSet = new GarbageCollectorMetricSet
+ val memGaugeSet = new MemoryUsageGaugeSet
+
+ metricRegistry.registerAll(gcMetricSet)
+ metricRegistry.registerAll(memGaugeSet)
+}
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
new file mode 100644
index 0000000000..3fee55cc6d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.source
+
+import com.codahale.metrics.MetricRegistry
+
+trait Source {
+ def sourceName: String
+ def metricRegistry: MetricRegistry
+}
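A minimal hypothetical Source, not part of this patch, showing the two members an implementation must provide; it exposes a single gauge reporting the JVM's uptime.

  import java.lang.management.ManagementFactory

  import com.codahale.metrics.{Gauge, MetricRegistry}

  // Hypothetical source exposing a single gauge.
  class UptimeSource extends Source {
    val sourceName = "uptime"
    val metricRegistry = new MetricRegistry()

    metricRegistry.register(MetricRegistry.name("jvm", "uptime", "ms"), new Gauge[Long] {
      override def getValue: Long = ManagementFactory.getRuntimeMXBean.getUptime
    })
  }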
diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
new file mode 100644
index 0000000000..f736bb3713
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+ extends Message(Message.BUFFER_MESSAGE, id_) {
+
+ val initialSize = currentSize()
+ var gotChunkForSendingOnce = false
+
+ def size = initialSize
+
+ def currentSize() = {
+ if (buffers == null || buffers.isEmpty) {
+ 0
+ } else {
+ buffers.map(_.remaining).reduceLeft(_ + _)
+ }
+ }
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+ if (maxChunkSize <= 0) {
+ throw new Exception("Max chunk size is " + maxChunkSize)
+ }
+
+ if (size == 0 && gotChunkForSendingOnce == false) {
+ val newChunk = new MessageChunk(
+ new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+
+ while(!buffers.isEmpty) {
+ val buffer = buffers(0)
+ if (buffer.remaining == 0) {
+ BlockManager.dispose(buffer)
+ buffers -= buffer
+ } else {
+ val newBuffer = if (buffer.remaining <= maxChunkSize) {
+ buffer.duplicate()
+ } else {
+ buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+ }
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ gotChunkForSendingOnce = true
+ return Some(newChunk)
+ }
+ }
+ None
+ }
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+ // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+ if (buffers.size > 1) {
+ throw new Exception("Attempting to get chunk from message with multiple data buffers")
+ }
+ val buffer = buffers(0)
+ if (buffer.remaining > 0) {
+ if (buffer.remaining < chunkSize) {
+ throw new Exception("Not enough space in data buffer for receiving chunk")
+ }
+ val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+ buffer.position(buffer.position + newBuffer.remaining)
+ val newChunk = new MessageChunk(new MessageChunkHeader(
+ typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+ return Some(newChunk)
+ }
+ None
+ }
+
+ def flip() {
+ buffers.foreach(_.flip)
+ }
+
+ def hasAckId() = (ackId != 0)
+
+ def isCompletelyReceived() = !buffers(0).hasRemaining
+
+ override def toString = {
+ if (hasAckId) {
+ "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+ } else {
+ "BufferMessage(id = " + id + ", size = " + size + ")"
+ }
+ }
+}
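A sketch, not part of this patch, of draining a BufferMessage into chunks the way a sending connection does; the message id, payload size, and chunk size are arbitrary.

  import java.nio.ByteBuffer
  import scala.collection.mutable.ArrayBuffer

  // A 100 KB payload split into chunks of at most 64 KB.
  val payload = ByteBuffer.wrap(new Array[Byte](100 * 1024))
  val message = new BufferMessage(1, ArrayBuffer(payload), 0)

  var chunk = message.getChunkForSending(64 * 1024)
  while (chunk.isDefined) {
    // each chunk carries a MessageChunkHeader plus a slice of the payload
    chunk = message.getChunkForSending(64 * 1024)
  }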
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
new file mode 100644
index 0000000000..95cb0206ac
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -0,0 +1,586 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+
+import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}
+
+import java.io._
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+
+
+private[spark]
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+ val socketRemoteConnectionManagerId: ConnectionManagerId)
+ extends Logging {
+
+ def this(channel_ : SocketChannel, selector_ : Selector) = {
+ this(channel_, selector_,
+ ConnectionManagerId.fromSocketAddress(
+ channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
+ }
+
+ channel.configureBlocking(false)
+ channel.socket.setTcpNoDelay(true)
+ channel.socket.setReuseAddress(true)
+ channel.socket.setKeepAlive(true)
+ /*channel.socket.setReceiveBufferSize(32768) */
+
+ @volatile private var closed = false
+ var onCloseCallback: Connection => Unit = null
+ var onExceptionCallback: (Connection, Exception) => Unit = null
+ var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
+
+ val remoteAddress = getRemoteAddress()
+
+ def resetForceReregister(): Boolean
+
+ // Read channels typically do not register for write and write does not for read
+ // Now, we do have write registering for read too (temporarily), but this is to detect
+ // channel close NOT to actually read/consume data on it !
+ // How does this work if/when we move to SSL ?
+
+ // What is the interest to register with selector for when we want this connection to be selected
+ def registerInterest()
+
+ // What is the interest to register with selector for when we want this connection to
+ // be de-selected
+ // Traditionally this is 0, but in our case (for example, for the close-detection hack in
+ // SendingConnection) it will be SelectionKey.OP_READ until we fix it properly
+ def unregisterInterest()
+
+ // On receiving a read event, should we change the interest for this channel or not ?
+ // Will be true for ReceivingConnection, false for SendingConnection.
+ def changeInterestForRead(): Boolean
+
+ // On receiving a write event, should we change the interest for this channel or not ?
+ // Will be false for ReceivingConnection, true for SendingConnection.
+ // Actually, for now, should not get triggered for ReceivingConnection
+ def changeInterestForWrite(): Boolean
+
+ def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ socketRemoteConnectionManagerId
+ }
+
+ def key() = channel.keyFor(selector)
+
+ def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+
+ // Returns whether we have to register for further reads or not.
+ def read(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot read on connection of type " + this.getClass.toString)
+ }
+
+ // Returns whether we have to register for further writes or not.
+ def write(): Boolean = {
+ throw new UnsupportedOperationException(
+ "Cannot write on connection of type " + this.getClass.toString)
+ }
+
+ def close() {
+ closed = true
+ val k = key()
+ if (k != null) {
+ k.cancel()
+ }
+ channel.close()
+ callOnCloseCallback()
+ }
+
+ protected def isClosed: Boolean = closed
+
+ def onClose(callback: Connection => Unit) {
+ onCloseCallback = callback
+ }
+
+ def onException(callback: (Connection, Exception) => Unit) {
+ onExceptionCallback = callback
+ }
+
+ def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+ onKeyInterestChangeCallback = callback
+ }
+
+ def callOnExceptionCallback(e: Exception) {
+ if (onExceptionCallback != null) {
+ onExceptionCallback(this, e)
+ } else {
+ logError("Error in connection to " + getRemoteConnectionManagerId() +
+ " and OnExceptionCallback not registered", e)
+ }
+ }
+
+ def callOnCloseCallback() {
+ if (onCloseCallback != null) {
+ onCloseCallback(this)
+ } else {
+ logWarning("Connection to " + getRemoteConnectionManagerId() +
+ " closed and OnExceptionCallback not registered")
+ }
+
+ }
+
+ def changeConnectionKeyInterest(ops: Int) {
+ if (onKeyInterestChangeCallback != null) {
+ onKeyInterestChangeCallback(this, ops)
+ } else {
+ throw new Exception("OnKeyInterestChangeCallback not registered")
+ }
+ }
+
+ def printRemainingBuffer(buffer: ByteBuffer) {
+ val bytes = new Array[Byte](buffer.remaining)
+ val curPosition = buffer.position
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ buffer.position(curPosition)
+ print(" (" + bytes.size + ")")
+ }
+
+ def printBuffer(buffer: ByteBuffer, position: Int, length: Int) {
+ val bytes = new Array[Byte](length)
+ val curPosition = buffer.position
+ buffer.position(position)
+ buffer.get(bytes)
+ bytes.foreach(x => print(x + " "))
+ print(" (" + position + ", " + length + ")")
+ buffer.position(curPosition)
+ }
+}
+
+
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+ remoteId_ : ConnectionManagerId)
+ extends Connection(SocketChannel.open, selector_, remoteId_) {
+
+ private class Outbox(fair: Int = 0) {
+ val messages = new Queue[Message]()
+ val defaultChunkSize = 65536 //32768 //16384
+ var nextMessageToBeUsed = 0
+
+ def addMessage(message: Message) {
+ messages.synchronized{
+ /*messages += message*/
+ messages.enqueue(message)
+ logDebug("Added [" + message + "] to outbox for sending to " +
+ "[" + getRemoteConnectionManagerId() + "]")
+ }
+ }
+
+ def getChunk(): Option[MessageChunk] = {
+ fair match {
+ case 0 => getChunkFIFO()
+ case 1 => getChunkRR()
+ case _ => throw new Exception("Unexpected fairness policy in outbox")
+ }
+ }
+
+ private def getChunkFIFO(): Option[MessageChunk] = {
+ /*logInfo("Using FIFO")*/
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ val message = messages(0)
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages += message // this is probably incorrect, it won't work as FIFO
+ if (!message.started) {
+ logDebug("Starting to send [" + message + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ return chunk
+ } else {
+ /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+ "] in " + message.timeTaken )
+ }
+ }
+ }
+ None
+ }
+
+ private def getChunkRR(): Option[MessageChunk] = {
+ messages.synchronized {
+ while (!messages.isEmpty) {
+ /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
+ /*val message = messages(nextMessageToBeUsed)*/
+ val message = messages.dequeue
+ val chunk = message.getChunkForSending(defaultChunkSize)
+ if (chunk.isDefined) {
+ messages.enqueue(message)
+ nextMessageToBeUsed = nextMessageToBeUsed + 1
+ if (!message.started) {
+ logDebug(
+ "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
+ message.started = true
+ message.startTime = System.currentTimeMillis
+ }
+ logTrace(
+ "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
+ return chunk
+ } else {
+ message.finishTime = System.currentTimeMillis
+ logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
+ "] in " + message.timeTaken )
+ }
+ }
+ }
+ None
+ }
+ }
+
+ // outbox is used as a lock - ensure that it is always used as a leaf (since methods which
+ // lock it are invoked in context of other locks)
+ private val outbox = new Outbox(1)
+ /*
+ This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly
+ different purpose. This flag is to see if we need to force reregister for write even when we
+ do not have any pending bytes to write to socket.
+ This can happen due to a race between adding pending buffers and checking for the existence of
+ data as detailed in https://github.com/mesos/spark/pull/791
+ */
+ private var needForceReregister = false
+ val currentBuffers = new ArrayBuffer[ByteBuffer]()
+
+ /*channel.socket.setSendBufferSize(256 * 1024)*/
+
+ override def getRemoteAddress() = address
+
+ val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+ override def registerInterest() {
+ // Registering read too - does not really help in most cases, but for some
+ // it does - so let us keep it for now.
+ changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(DEFAULT_INTEREST)
+ }
+
+ def send(message: Message) {
+ outbox.synchronized {
+ outbox.addMessage(message)
+ needForceReregister = true
+ }
+ if (channel.isConnected) {
+ registerInterest()
+ }
+ }
+
+ // return previous value after resetting it.
+ def resetForceReregister(): Boolean = {
+ outbox.synchronized {
+ val result = needForceReregister
+ needForceReregister = false
+ result
+ }
+ }
+
+ // MUST be called within the selector loop
+ def connect() {
+ try{
+ channel.register(selector, SelectionKey.OP_CONNECT)
+ channel.connect(address)
+ logInfo("Initiating connection to [" + address + "]")
+ } catch {
+ case e: Exception => {
+ logError("Error connecting to " + address, e)
+ callOnExceptionCallback(e)
+ }
+ }
+ }
+
+ def finishConnect(force: Boolean): Boolean = {
+ try {
+      // Typically this should finish immediately, since it was triggered by a connect
+      // selection, though it need not always complete successfully.
+ val connected = channel.finishConnect
+ if (!force && !connected) {
+ logInfo(
+ "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+ return false
+ }
+
+      // Fall back to the previous behavior and assume finishConnect completed.
+      // This happens only after finishConnect has failed repeatedly (10 times or so),
+      // which is highly unlikely unless the socket was closed uncleanly.
+ registerInterest()
+ logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+ return true
+ } catch {
+ case e: Exception => {
+ logWarning("Error finishing connection to " + address, e)
+ callOnExceptionCallback(e)
+ // ignore
+ return true
+ }
+ }
+ }
+
+ override def write(): Boolean = {
+ try {
+ while (true) {
+ if (currentBuffers.size == 0) {
+ outbox.synchronized {
+ outbox.getChunk() match {
+ case Some(chunk) => {
+ val buffers = chunk.buffers
+              // If we have 'seen' the pending data, reset the flag - the normal re-registration
+              // for the write event (below) takes care of it.
+ if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister()
+ currentBuffers ++= buffers
+ }
+ case None => {
+ // changeConnectionKeyInterest(0)
+ /*key.interestOps(0)*/
+ return false
+ }
+ }
+ }
+ }
+
+ if (currentBuffers.size > 0) {
+ val buffer = currentBuffers(0)
+ val remainingBytes = buffer.remaining
+ val writtenBytes = channel.write(buffer)
+ if (buffer.remaining == 0) {
+ currentBuffers -= buffer
+ }
+ if (writtenBytes < remainingBytes) {
+ // re-register for write.
+ return true
+ }
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ return true
+ }
+
+  // This is a hack to determine whether the remote socket was closed or not.
+  // A SendingConnection does NOT expect to receive any data - if it does, it is an error.
+  // In most cases read() returns -1 when the remote socket is closed, so we register for
+  // reads purely to detect that.
+  override def read(): Boolean = {
+    // We don't expect the other side to send anything, so we just read to detect an error or EOF.
+ try {
+ val length = channel.read(ByteBuffer.allocate(1))
+ if (length == -1) { // EOF
+ close()
+ } else if (length > 0) {
+ logWarning(
+ "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ }
+
+ false
+ }
+
+ override def changeInterestForRead(): Boolean = false
+
+ override def changeInterestForWrite(): Boolean = ! isClosed
+}
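
The round-robin policy used by Outbox.getChunkRR above is easiest to see in isolation. Below is a minimal, self-contained sketch of the same idea, assuming nothing from this patch; MiniOutbox, Msg and nextChunk are illustrative names only. Each queued message gives up one fixed-size chunk per turn and is then re-enqueued behind the others, so a large message cannot starve small ones.

import scala.collection.mutable.Queue

object RoundRobinSketch {
  final case class Msg(id: Int, var data: String)

  class MiniOutbox(chunkSize: Int) {
    private val messages = new Queue[Msg]()

    def add(m: Msg): Unit = messages.synchronized { messages.enqueue(m) }

    // Dequeue a message, cut one chunk off it, and re-enqueue it behind the others,
    // mirroring getChunkRR above (dequeue + getChunkForSending + enqueue).
    def nextChunk(): Option[(Int, String)] = messages.synchronized {
      while (messages.nonEmpty) {
        val m = messages.dequeue()
        if (m.data.nonEmpty) {
          val (chunk, rest) = m.data.splitAt(chunkSize)
          m.data = rest
          messages.enqueue(m)
          return Some((m.id, chunk))
        }
        // Message fully consumed: drop it (the real code logs "Finished sending" here).
      }
      None
    }
  }

  def main(args: Array[String]): Unit = {
    val outbox = new MiniOutbox(chunkSize = 4)
    outbox.add(Msg(1, "AAAAAAAAAAAA")) // 12 bytes -> 3 chunks
    outbox.add(Msg(2, "BBBB"))         // 4 bytes  -> 1 chunk
    Iterator.continually(outbox.nextChunk()).takeWhile(_.isDefined).foreach(c => println(c.get))
  }
}

Running it prints chunks for message ids 1, 2, 1, 1: the short message 2 gets onto the wire after a single chunk of message 1 instead of waiting for all of it.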
+
+
+// Must be created within selector loop - else deadlock
+private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
+ extends Connection(channel_, selector_) {
+
+ class Inbox() {
+ val messages = new HashMap[Int, BufferMessage]()
+
+ def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
+
+ def createNewMessage: BufferMessage = {
+ val newMessage = Message.create(header).asInstanceOf[BufferMessage]
+ newMessage.started = true
+ newMessage.startTime = System.currentTimeMillis
+ logDebug(
+ "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
+ messages += ((newMessage.id, newMessage))
+ newMessage
+ }
+
+ val message = messages.getOrElseUpdate(header.id, createNewMessage)
+ logTrace(
+ "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
+ message.getChunkForReceiving(header.chunkSize)
+ }
+
+ def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
+ messages.get(chunk.header.id)
+ }
+
+ def removeMessage(message: Message) {
+ messages -= message.id
+ }
+ }
+
+ @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+ override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+ val currId = inferredRemoteManagerId
+ if (currId != null) currId else super.getRemoteConnectionManagerId()
+ }
+
+  // The remote address seen on a receiving socket is the sender's ephemeral local socket,
+  // which is NOT the sender's connection manager id.
+  // We infer that id from the message headers we receive on this socket.
+ private def processConnectionManagerId(header: MessageChunkHeader) {
+ val currId = inferredRemoteManagerId
+ if (header.address == null || currId != null) return
+
+ val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+ if (managerId != null) {
+ inferredRemoteManagerId = managerId
+ }
+ }
+
+
+ val inbox = new Inbox()
+ val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
+ var onReceiveCallback: (Connection , Message) => Unit = null
+ var currentChunk: MessageChunk = null
+
+ channel.register(selector, SelectionKey.OP_READ)
+
+ override def read(): Boolean = {
+ try {
+ while (true) {
+ if (currentChunk == null) {
+ val headerBytesRead = channel.read(headerBuffer)
+ if (headerBytesRead == -1) {
+ close()
+ return false
+ }
+ if (headerBuffer.remaining > 0) {
+ // re-register for read event ...
+ return true
+ }
+ headerBuffer.flip
+ if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
+ throw new Exception(
+ "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+ }
+ val header = MessageChunkHeader.create(headerBuffer)
+ headerBuffer.clear()
+
+ processConnectionManagerId(header)
+
+ header.typ match {
+ case Message.BUFFER_MESSAGE => {
+ if (header.totalSize == 0) {
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, Message.create(header))
+ }
+ currentChunk = null
+ // re-register for read event ...
+ return true
+ } else {
+ currentChunk = inbox.getChunk(header).orNull
+ }
+ }
+ case _ => throw new Exception("Message of unknown type received")
+ }
+ }
+
+ if (currentChunk == null) throw new Exception("No message chunk to receive data")
+
+ val bytesRead = channel.read(currentChunk.buffer)
+ if (bytesRead == 0) {
+ // re-register for read event ...
+ return true
+ } else if (bytesRead == -1) {
+ close()
+ return false
+ }
+
+ /*logDebug("Read " + bytesRead + " bytes for the buffer")*/
+
+ if (currentChunk.buffer.remaining == 0) {
+ /*println("Filled buffer at " + System.currentTimeMillis)*/
+ val bufferMessage = inbox.getMessageForChunk(currentChunk).get
+ if (bufferMessage.isCompletelyReceived) {
+ bufferMessage.flip
+ bufferMessage.finishTime = System.currentTimeMillis
+ logDebug("Finished receiving [" + bufferMessage + "] from " +
+ "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
+ if (onReceiveCallback != null) {
+ onReceiveCallback(this, bufferMessage)
+ }
+ inbox.removeMessage(bufferMessage)
+ }
+ currentChunk = null
+ }
+ }
+ } catch {
+ case e: Exception => {
+ logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
+ callOnExceptionCallback(e)
+ close()
+ return false
+ }
+ }
+ // should not happen - to keep scala compiler happy
+ return true
+ }
+
+ def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
+
+ // override def changeInterestForRead(): Boolean = ! isClosed
+ override def changeInterestForRead(): Boolean = true
+
+ override def changeInterestForWrite(): Boolean = {
+ throw new IllegalStateException("Unexpected invocation right now")
+ }
+
+ override def registerInterest() {
+    // A receiving connection only ever needs to be registered for read events.
+ changeConnectionKeyInterest(SelectionKey.OP_READ)
+ }
+
+ override def unregisterInterest() {
+ changeConnectionKeyInterest(0)
+ }
+
+ // For read conn, always false.
+ override def resetForceReregister(): Boolean = false
+}
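
As the comment on SendingConnection.read above notes, registering a write-only socket for reads is just a trick to learn that the peer closed its end: read() returning -1 signals EOF. A tiny standalone JDK NIO sketch of that behaviour, with no Spark classes involved and blocking channels for brevity:

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{ServerSocketChannel, SocketChannel}

object EofProbeSketch {
  def main(args: Array[String]): Unit = {
    val server = ServerSocketChannel.open()
    server.socket().bind(new InetSocketAddress("localhost", 0))
    val client = SocketChannel.open(new InetSocketAddress("localhost", server.socket().getLocalPort))
    server.accept().close()                   // the remote side closes its end of the socket
    val n = client.read(ByteBuffer.allocate(1))
    println("read returned " + n)             // -1 means EOF, i.e. the peer has gone away
    client.close()
    server.close()
  }
}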
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
new file mode 100644
index 0000000000..9e2233c07b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -0,0 +1,720 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+
+import java.nio._
+import java.nio.channels._
+import java.nio.channels.spi._
+import java.net._
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
+
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.SynchronizedMap
+import scala.collection.mutable.SynchronizedQueue
+import scala.collection.mutable.ArrayBuffer
+
+import akka.dispatch.{Await, Promise, ExecutionContext, Future}
+import akka.util.Duration
+import akka.util.duration._
+
+
+private[spark] class ConnectionManager(port: Int) extends Logging {
+
+ class MessageStatus(
+ val message: Message,
+ val connectionManagerId: ConnectionManagerId,
+ completionHandler: MessageStatus => Unit) {
+
+ var ackMessage: Option[Message] = None
+ var attempted = false
+ var acked = false
+
+ def markDone() { completionHandler(this) }
+ }
+
+ private val selector = SelectorProvider.provider.openSelector()
+
+ private val handleMessageExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+ System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+ System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val handleReadWriteExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+ System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+ System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+  // Use a separate, smaller thread pool for connect events - they are infrequent and very
+  // short lived, and should be executed asap.
+ private val handleConnectExecutor = new ThreadPoolExecutor(
+ System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+ System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+ System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+ new LinkedBlockingDeque[Runnable]())
+
+ private val serverChannel = ServerSocketChannel.open()
+ private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+ private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+ private val messageStatuses = new HashMap[Int, MessageStatus]
+ private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+ private val registerRequests = new SynchronizedQueue[SendingConnection]
+
+ implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+
+ private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+
+ serverChannel.configureBlocking(false)
+ serverChannel.socket.setReuseAddress(true)
+ serverChannel.socket.setReceiveBufferSize(256 * 1024)
+
+ serverChannel.socket.bind(new InetSocketAddress(port))
+ serverChannel.register(selector, SelectionKey.OP_ACCEPT)
+
+ val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
+ logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
+
+ private val selectorThread = new Thread("connection-manager-thread") {
+ override def run() = ConnectionManager.this.run()
+ }
+ selectorThread.setDaemon(true)
+ selectorThread.start()
+
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerWrite(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ writeRunnableStarted.synchronized {
+ // So that we do not trigger more write events while processing this one.
+ // The write method will re-register when done.
+ if (conn.changeInterestForWrite()) conn.unregisterInterest()
+ if (writeRunnableStarted.contains(key)) {
+ // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+ return
+ }
+
+ writeRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.write()
+ } finally {
+ writeRunnableStarted.synchronized {
+ writeRunnableStarted -= key
+ val needReregister = register || conn.resetForceReregister()
+ if (needReregister && conn.changeInterestForWrite()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+ private def triggerRead(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ readRunnableStarted.synchronized {
+ // So that we do not trigger more read events while processing this one.
+ // The read method will re-register when done.
+      if (conn.changeInterestForRead()) conn.unregisterInterest()
+ if (readRunnableStarted.contains(key)) {
+ return
+ }
+
+ readRunnableStarted += key
+ }
+ handleReadWriteExecutor.execute(new Runnable {
+ override def run() {
+ var register: Boolean = false
+ try {
+ register = conn.read()
+ } finally {
+ readRunnableStarted.synchronized {
+ readRunnableStarted -= key
+ if (register && conn.changeInterestForRead()) {
+ conn.registerInterest()
+ }
+ }
+ }
+ }
+ } )
+ }
+
+ private def triggerConnect(key: SelectionKey) {
+ val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+ if (conn == null) return
+
+ // prevent other events from being triggered
+ // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+ conn.changeConnectionKeyInterest(0)
+
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+
+ var tries: Int = 10
+ while (tries >= 0) {
+ if (conn.finishConnect(false)) return
+ // Sleep ?
+ Thread.sleep(1)
+ tries -= 1
+ }
+
+        // Fall back to the previous behavior: we should not really get here, since this method
+        // is triggered only when the channel becomes connectable, but at times the first
+        // finishConnect does not succeed - hence the loop above to retry a few times.
+ conn.finishConnect(true)
+ }
+ } )
+ }
+
+ // MUST be called within selector loop - else deadlock.
+ private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+ try {
+ key.interestOps(0)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+
+ val conn = connectionsByKey.getOrElse(key, null)
+ if (conn == null) return
+
+ // Pushing to connect threadpool
+ handleConnectExecutor.execute(new Runnable {
+ override def run() {
+ try {
+ conn.callOnExceptionCallback(e)
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ try {
+ conn.close()
+ } catch {
+ // ignore exceptions
+ case e: Exception => logDebug("Ignoring exception", e)
+ }
+ }
+ })
+ }
+
+
+ def run() {
+ try {
+ while(!selectorThread.isInterrupted) {
+ while (! registerRequests.isEmpty) {
+ val conn: SendingConnection = registerRequests.dequeue
+ addListeners(conn)
+ conn.connect()
+ addConnection(conn)
+ }
+
+ while(!keyInterestChangeRequests.isEmpty) {
+ val (key, ops) = keyInterestChangeRequests.dequeue
+
+ try {
+ if (key.isValid) {
+ val connection = connectionsByKey.getOrElse(key, null)
+ if (connection != null) {
+ val lastOps = key.interestOps()
+ key.interestOps(ops)
+
+ // hot loop - prevent materialization of string if trace not enabled.
+ if (isTraceEnabled()) {
+ def intToOpStr(op: Int): String = {
+ val opStrs = ArrayBuffer[String]()
+ if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+ if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+ if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+ if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+ if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+ }
+
+ logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+ "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+ }
+ }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+
+ val selectedKeysCount =
+ try {
+ selector.select()
+ } catch {
+ // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+ case e: CancelledKeyException => {
+            // Some keys within the selector's key set are invalid or closed; clear them.
+ val allKeys = selector.keys().iterator()
+
+ while (allKeys.hasNext()) {
+ val key = allKeys.next()
+ try {
+ if (! key.isValid) {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ 0
+ }
+
+ if (selectedKeysCount == 0) {
+ logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+ }
+ if (selectorThread.isInterrupted) {
+ logInfo("Selector thread was interrupted!")
+ return
+ }
+
+ if (0 != selectedKeysCount) {
+ val selectedKeys = selector.selectedKeys().iterator()
+ while (selectedKeys.hasNext()) {
+ val key = selectedKeys.next
+ selectedKeys.remove()
+ try {
+ if (key.isValid) {
+                  if (key.isAcceptable) {
+                    acceptConnection(key)
+                  } else if (key.isConnectable) {
+                    triggerConnect(key)
+                  } else if (key.isReadable) {
+                    triggerRead(key)
+                  } else if (key.isWritable) {
+                    triggerWrite(key)
+                  }
+ } else {
+ logInfo("Key not valid ? " + key)
+ throw new CancelledKeyException()
+ }
+ } catch {
+ // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+ case e: CancelledKeyException => {
+ logInfo("key already cancelled ? " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ case e: Exception => {
+ logError("Exception processing key " + key, e)
+ triggerForceCloseByException(key, e)
+ }
+ }
+ }
+ }
+ }
+ } catch {
+ case e: Exception => logError("Error in select loop", e)
+ }
+ }
+
+ def acceptConnection(key: SelectionKey) {
+ val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
+
+ var newChannel = serverChannel.accept()
+
+    // Accept them all in a tight loop; a non-blocking accept with no processing should be fine.
+ while (newChannel != null) {
+ try {
+ val newConnection = new ReceivingConnection(newChannel, selector)
+ newConnection.onReceive(receiveMessage)
+ addListeners(newConnection)
+ addConnection(newConnection)
+ logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+ } catch {
+ // might happen in case of issues with registering with selector
+ case e: Exception => logError("Error in accept loop", e)
+ }
+
+ newChannel = serverChannel.accept()
+ }
+ }
+
+ private def addListeners(connection: Connection) {
+ connection.onKeyInterestChange(changeConnectionKeyInterest)
+ connection.onException(handleConnectionError)
+ connection.onClose(removeConnection)
+ }
+
+ def addConnection(connection: Connection) {
+ connectionsByKey += ((connection.key, connection))
+ }
+
+ def removeConnection(connection: Connection) {
+ connectionsByKey -= connection.key
+
+ try {
+ if (connection.isInstanceOf[SendingConnection]) {
+ val sendingConnection = connection.asInstanceOf[SendingConnection]
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+ connectionsById -= sendingConnectionManagerId
+
+ messageStatuses.synchronized {
+ messageStatuses
+ .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+ logInfo("Notifying " + status)
+ status.synchronized {
+ status.attempted = true
+ status.acked = false
+ status.markDone()
+ }
+ })
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ } else if (connection.isInstanceOf[ReceivingConnection]) {
+ val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+ val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+ logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+ val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+ if (! sendingConnectionOpt.isDefined) {
+ logError("Corresponding SendingConnectionManagerId not found")
+ return
+ }
+
+ val sendingConnection = sendingConnectionOpt.get
+ connectionsById -= remoteConnectionManagerId
+ sendingConnection.close()
+
+ val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+ assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+ messageStatuses.synchronized {
+ for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+ logInfo("Notifying " + s)
+ s.synchronized {
+ s.attempted = true
+ s.acked = false
+ s.markDone()
+ }
+ }
+
+ messageStatuses.retain((i, status) => {
+ status.connectionManagerId != sendingConnectionManagerId
+ })
+ }
+ }
+ } finally {
+ // So that the selection keys can be removed.
+ wakeupSelector()
+ }
+ }
+
+ def handleConnectionError(connection: Connection, e: Exception) {
+ logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
+ removeConnection(connection)
+ }
+
+ def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+ keyInterestChangeRequests += ((connection.key, ops))
+    // so that registrations happen !
+ wakeupSelector()
+ }
+
+ def receiveMessage(connection: Connection, message: Message) {
+ val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
+ logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
+ val runnable = new Runnable() {
+ val creationTime = System.currentTimeMillis
+ def run() {
+ logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ handleMessage(connectionManagerId, message)
+ logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms")
+ }
+ }
+ handleMessageExecutor.execute(runnable)
+ /*handleMessage(connection, message)*/
+ }
+
+ private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ logDebug("Handling [" + message + "] from [" + connectionManagerId + "]")
+ message match {
+ case bufferMessage: BufferMessage => {
+ if (bufferMessage.hasAckId) {
+ val sentMessageStatus = messageStatuses.synchronized {
+ messageStatuses.get(bufferMessage.ackId) match {
+ case Some(status) => {
+ messageStatuses -= bufferMessage.ackId
+ status
+ }
+ case None => {
+ throw new Exception("Could not find reference for received ack message " + message.id)
+ null
+ }
+ }
+ }
+ sentMessageStatus.synchronized {
+ sentMessageStatus.ackMessage = Some(message)
+ sentMessageStatus.attempted = true
+ sentMessageStatus.acked = true
+ sentMessageStatus.markDone()
+ }
+ } else {
+ val ackMessage = if (onReceiveCallback != null) {
+ logDebug("Calling back")
+ onReceiveCallback(bufferMessage, connectionManagerId)
+ } else {
+ logDebug("Not calling back as callback is null")
+ None
+ }
+
+ if (ackMessage.isDefined) {
+ if (!ackMessage.get.isInstanceOf[BufferMessage]) {
+ logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
+ } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
+ logDebug("Response to " + bufferMessage + " does not have ack id set")
+ ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
+ }
+ }
+
+ sendMessage(connectionManagerId, ackMessage.getOrElse {
+ Message.createBufferMessage(bufferMessage.id)
+ })
+ }
+ }
+ case _ => throw new Exception("Unknown type message received")
+ }
+ }
+
+ private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
+ def startNewConnection(): SendingConnection = {
+ val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
+ val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+ registerRequests.enqueue(newConnection)
+
+ newConnection
+ }
+ // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
+ // If we do re-add it, we should consistently use it everywhere I guess ?
+ val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
+ message.senderAddress = id.toSocketAddress()
+ logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
+ connection.send(message)
+
+ wakeupSelector()
+ }
+
+ private def wakeupSelector() {
+ selector.wakeup()
+ }
+
+ def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message)
+ : Future[Option[Message]] = {
+ val promise = Promise[Option[Message]]
+ val status = new MessageStatus(message, connectionManagerId, s => promise.success(s.ackMessage))
+ messageStatuses.synchronized {
+ messageStatuses += ((message.id, status))
+ }
+ sendMessage(connectionManagerId, message)
+ promise.future
+ }
+
+ def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = {
+ Await.result(sendMessageReliably(connectionManagerId, message), Duration.Inf)
+ }
+
+ def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) {
+ onReceiveCallback = callback
+ }
+
+ def stop() {
+ selectorThread.interrupt()
+ selectorThread.join()
+ selector.close()
+ val connections = connectionsByKey.values
+ connections.foreach(_.close())
+ if (connectionsByKey.size != 0) {
+ logWarning("All connections not cleaned up")
+ }
+ handleMessageExecutor.shutdown()
+ handleReadWriteExecutor.shutdown()
+ handleConnectExecutor.shutdown()
+ logInfo("ConnectionManager stopped")
+ }
+}
+
+
+private[spark] object ConnectionManager {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ /*testSequentialSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelSending(manager)*/
+ /*System.gc()*/
+
+ /*testParallelDecreasingSending(manager)*/
+ /*System.gc()*/
+
+ testContinuousSending(manager)
+ System.gc()
+ }
+
+ def testSequentialSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Sequential Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliablySync(manager.id, bufferMessage)
+ })
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ println("Started at " + startTime + ", finished at " + finishTime)
+ println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testParallelDecreasingSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Parallel Decreasing Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+ val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
+ buffers.foreach(_.flip)
+ val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0
+
+ val startTime = System.currentTimeMillis
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("--------------------------")
+ /*println("Started at " + startTime + ", finished at " + finishTime) */
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+
+ def testContinuousSending(manager: ConnectionManager) {
+ println("--------------------------")
+ println("Continuous Sending")
+ println("--------------------------")
+ val size = 10 * 1024 * 1024
+ val count = 10
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ while(true) {
+ (0 until count).map(i => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ manager.sendMessageReliably(manager.id, bufferMessage)
+ }).foreach(f => {
+ val g = Await.result(f, 1 second)
+ if (!g.isDefined) println("Failed")
+ })
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(1000)
+ val mb = size * count / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val tput = mb * 1000.0 / ms
+ println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)")
+ println("--------------------------")
+ println()
+ }
+ }
+}
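
For orientation, here is a minimal local usage sketch of the API defined above; it mirrors the ReceiverTest and SenderTest utilities later in this patch. It has to live under org.apache.spark because these classes are private[spark]; the object name, payloads and the use of port 0 (ephemeral ports) are illustrative, and the commented system property is one of the tuning knobs read by the executors above.

package org.apache.spark.network

import java.nio.ByteBuffer

object ConnectionManagerUsageSketch {
  def main(args: Array[String]) {
    // Pool sizes are plain system properties, e.g.:
    // System.setProperty("spark.core.connection.handler.threads.max", "8")

    val server = new ConnectionManager(0)      // port 0 binds an ephemeral port
    server.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
      println("server received " + msg + " from " + id)
      Some(Message.createBufferMessage(ByteBuffer.wrap("pong".getBytes), msg.id))
    })

    val client = new ConnectionManager(0)
    val request = Message.createBufferMessage(ByteBuffer.wrap("ping".getBytes))
    val ack = client.sendMessageReliablySync(server.id, request)
    println("client got ack: " + ack)

    client.stop()
    server.stop()
  }
}

sendMessageReliablySync blocks until the server's callback result comes back as the ack, which is exactly the path that handleMessage and MessageStatus implement.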
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
new file mode 100644
index 0000000000..0839c011b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.net.InetSocketAddress
+
+import org.apache.spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+ def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+ new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+ }
+}
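
connectionsById in ConnectionManager is keyed by this case class, so equality after an address round trip matters. A quick sanity-check sketch, assuming a plain hostname (7077 is just an example port; an id built from a raw IP could resolve to a different host name via getHostName):

package org.apache.spark.network

object ConnectionManagerIdSketch {
  def main(args: Array[String]) {
    val id = ConnectionManagerId("localhost", 7077)
    val roundTripped = ConnectionManagerId.fromSocketAddress(id.toSocketAddress())
    println(roundTripped == id)   // true: host and port survive the round trip
  }
}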
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
new file mode 100644
index 0000000000..8d9ad9604d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+import scala.io.Source
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+import akka.dispatch.Await
+import akka.util.duration._
+
+private[spark] object ConnectionManagerTest extends Logging{
+ def main(args: Array[String]) {
+    // <mesos cluster> - the master URL
+    // <slaves file> - a file listing the slaves to run connectionTest on
+    // [num of tasks] - the number of parallel tasks to launch; defaults to the number of slave hosts
+    // [size of msg in MB (integer)] - the size of messages to be sent in each task; defaults to 10
+    // [count] - how many times to run; defaults to 3
+    // [await time in seconds] - how long to await each future; defaults to 600
+    if (args.length < 2) {
+      println("Usage: ConnectionManagerTest <mesos cluster> <slaves file> [num of tasks] " +
+        "[size of msg in MB (integer)] [count] [await time in seconds]")
+ System.exit(1)
+ }
+
+ if (args(0).startsWith("local")) {
+ println("This runs only on a mesos cluster")
+ }
+
+ val sc = new SparkContext(args(0), "ConnectionManagerTest")
+ val slavesFile = Source.fromFile(args(1))
+ val slaves = slavesFile.mkString.split("\n")
+ slavesFile.close()
+
+ /*println("Slaves")*/
+ /*slaves.foreach(println)*/
+ val tasknum = if (args.length > 2) args(2).toInt else slaves.length
+ val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024
+ val count = if (args.length > 4) args(4).toInt else 3
+ val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second
+ println("Running "+count+" rounds of test: " + "parallel tasks = " + tasknum + ", msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime)
+ val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map(
+ i => SparkEnv.get.connectionManager.id).collect()
+ println("\nSlave ConnectionManagerIds")
+ slaveConnManagerIds.foreach(println)
+ println
+
+ (0 until count).foreach(i => {
+ val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => {
+ val connManager = SparkEnv.get.connectionManager
+ val thisConnManagerId = connManager.id
+ connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ logInfo("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val startTime = System.currentTimeMillis
+ val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => {
+ val bufferMessage = Message.createBufferMessage(buffer.duplicate)
+ logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]")
+ connManager.sendMessageReliably(slaveConnManagerId, bufferMessage)
+ })
+ val results = futures.map(f => Await.result(f, awaitTime))
+ val finishTime = System.currentTimeMillis
+ Thread.sleep(5000)
+
+ val mb = size * results.size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"
+ logInfo(resultStr)
+ resultStr
+ }).collect()
+
+ println("---------------------")
+ println("Run " + i)
+ resultStrs.foreach(println)
+ println("---------------------")
+ })
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala
new file mode 100644
index 0000000000..f2ecc6d439
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/Message.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetSocketAddress
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[spark] abstract class Message(val typ: Long, val id: Int) {
+ var senderAddress: InetSocketAddress = null
+ var started = false
+ var startTime = -1L
+ var finishTime = -1L
+
+ def size: Int
+
+ def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
+
+ def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
+
+ def timeTaken(): String = (finishTime - startTime).toString + " ms"
+
+ override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
+}
+
+
+private[spark] object Message {
+ val BUFFER_MESSAGE = 1111111111L
+
+ var lastId = 1
+
+ def getNewId() = synchronized {
+ lastId += 1
+ if (lastId == 0) {
+ lastId += 1
+ }
+ lastId
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
+ if (dataBuffers == null) {
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
+ }
+ if (dataBuffers.exists(_ == null)) {
+ throw new Exception("Attempting to create buffer message with null buffer")
+ }
+ return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId)
+ }
+
+ def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
+ createBufferMessage(dataBuffers, 0)
+
+ def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
+ if (dataBuffer == null) {
+ return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
+ } else {
+ return createBufferMessage(Array(dataBuffer), ackId)
+ }
+ }
+
+ def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
+ createBufferMessage(dataBuffer, 0)
+
+ def createBufferMessage(ackId: Int): BufferMessage = {
+ createBufferMessage(new Array[ByteBuffer](0), ackId)
+ }
+
+ def create(header: MessageChunkHeader): Message = {
+ val newMessage: Message = header.typ match {
+ case BUFFER_MESSAGE => new BufferMessage(header.id,
+ ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+ }
+ newMessage.senderAddress = header.address
+ newMessage
+ }
+}
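
The ack handling in ConnectionManager.handleMessage hinges on a reply carrying the request's id in its ackId field (BufferMessage itself is defined outside this excerpt). A small sketch of that pairing, using only the factory methods above; it must sit under org.apache.spark because Message is private[spark]:

package org.apache.spark.network

import java.nio.ByteBuffer

object AckPairingSketch {
  def main(args: Array[String]) {
    val request = Message.createBufferMessage(ByteBuffer.wrap("ping".getBytes))
    // A reply created with the request's id as its ackId is what lets the sender look up
    // the matching MessageStatus (keyed by request.id) and mark it done.
    val reply = Message.createBufferMessage(ByteBuffer.wrap("pong".getBytes), request.id)
    println("request.id = " + request.id + ", reply.ackId = " + reply.ackId)
  }
}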
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
new file mode 100644
index 0000000000..e0fe57b80d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunk.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+ val size = if (buffer == null) 0 else buffer.remaining
+
+ lazy val buffers = {
+ val ab = new ArrayBuffer[ByteBuffer]()
+ ab += header.buffer
+ if (buffer != null) {
+ ab += buffer
+ }
+ ab
+ }
+
+ override def toString = {
+ "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
new file mode 100644
index 0000000000..235fbc39b3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+
+
+private[spark] class MessageChunkHeader(
+ val typ: Long,
+ val id: Int,
+ val totalSize: Int,
+ val chunkSize: Int,
+ val other: Int,
+ val address: InetSocketAddress) {
+ lazy val buffer = {
+ // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
+ // Refer to network.Connection
+ val ip = address.getAddress.getAddress()
+ val port = address.getPort()
+ ByteBuffer.
+ allocate(MessageChunkHeader.HEADER_SIZE).
+ putLong(typ).
+ putInt(id).
+ putInt(totalSize).
+ putInt(chunkSize).
+ putInt(other).
+ putInt(ip.size).
+ put(ip).
+ putInt(port).
+ position(MessageChunkHeader.HEADER_SIZE).
+ flip.asInstanceOf[ByteBuffer]
+ }
+
+ override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
+ " and sizes " + totalSize + " / " + chunkSize + " bytes"
+}
+
+
+private[spark] object MessageChunkHeader {
+ val HEADER_SIZE = 40
+
+ def create(buffer: ByteBuffer): MessageChunkHeader = {
+ if (buffer.remaining != HEADER_SIZE) {
+ throw new IllegalArgumentException("Cannot convert buffer data to Message")
+ }
+ val typ = buffer.getLong()
+ val id = buffer.getInt()
+ val totalSize = buffer.getInt()
+ val chunkSize = buffer.getInt()
+ val other = buffer.getInt()
+ val ipSize = buffer.getInt()
+ val ipBytes = new Array[Byte](ipSize)
+ buffer.get(ipBytes)
+ val ip = InetAddress.getByAddress(ipBytes)
+ val port = buffer.getInt()
+ new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
+ }
+}
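
For an IPv4 address the fields above occupy 8 (typ) + 4 (id) + 4 (totalSize) + 4 (chunkSize) + 4 (other) + 4 (ip length) + 4 (address) + 4 (port) = 36 bytes, and position(HEADER_SIZE) pads the buffer to the fixed 40. A round-trip sketch through create(), placed in org.apache.spark.network since the class is private[spark]; the id and sizes are arbitrary:

package org.apache.spark.network

import java.net.InetSocketAddress

object HeaderRoundTripSketch {
  def main(args: Array[String]) {
    val addr = new InetSocketAddress("localhost", 7077)   // example address
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 42, 1024, 256, 0, addr)
    // buffer is exactly HEADER_SIZE (40) bytes; create() refuses anything else.
    val parsed = MessageChunkHeader.create(header.buffer.duplicate())
    println(parsed)   // MessageChunkHeader:42 of type 1111111111 and sizes 1024 / 256 bytes
  }
}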
diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
new file mode 100644
index 0000000000..781715108b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+private[spark] object ReceiverTest {
+
+ def main(args: Array[String]) {
+ val manager = new ConnectionManager(9999)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/
+ val buffer = ByteBuffer.wrap("response".getBytes())
+ Some(Message.createBufferMessage(buffer, msg.id))
+ })
+ Thread.currentThread.join()
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
new file mode 100644
index 0000000000..777574980f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network
+
+import java.nio.ByteBuffer
+import java.net.InetAddress
+
+private[spark] object SenderTest {
+
+ def main(args: Array[String]) {
+
+ if (args.length < 2) {
+ println("Usage: SenderTest <target host> <target port>")
+ System.exit(1)
+ }
+
+ val targetHost = args(0)
+ val targetPort = args(1).toInt
+ val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort)
+
+ val manager = new ConnectionManager(0)
+ println("Started connection manager with id = " + manager.id)
+
+ manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
+ println("Received [" + msg + "] from [" + id + "]")
+ None
+ })
+
+ val size = 100 * 1024 * 1024
+ val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
+ buffer.flip
+
+ val targetServer = args(0)
+
+ val count = 100
+ (0 until count).foreach(i => {
+ val dataMessage = Message.createBufferMessage(buffer.duplicate)
+ val startTime = System.currentTimeMillis
+ /*println("Started timer at " + startTime)*/
+ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match {
+ case Some(response) =>
+ val buffer = response.asInstanceOf[BufferMessage].buffers(0)
+ new String(buffer.array)
+ case None => "none"
+ }
+ val finishTime = System.currentTimeMillis
+ val mb = size / 1024.0 / 1024.0
+ val ms = finishTime - startTime
+ /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/
+ val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr
+ println(resultStr)
+ })
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
new file mode 100644
index 0000000000..3c29700920
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import io.netty.buffer._
+
+import org.apache.spark.Logging
+
+private[spark] class FileHeader (
+ val fileLen: Int,
+ val blockId: String) extends Logging {
+
+ lazy val buffer = {
+ val buf = Unpooled.buffer()
+ buf.capacity(FileHeader.HEADER_SIZE)
+ buf.writeInt(fileLen)
+ buf.writeInt(blockId.length)
+ blockId.foreach((x: Char) => buf.writeByte(x))
+    // Pad the rest of the header.
+    if (FileHeader.HEADER_SIZE - buf.readableBytes > 0) {
+      buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
+    } else {
+      throw new Exception("too long header " + buf.readableBytes)
+    }
+ buf
+ }
+
+}
+
+private[spark] object FileHeader {
+
+ val HEADER_SIZE = 40
+
+ def getFileLenOffset = 0
+ def getFileLenSize = Integer.SIZE/8
+
+ def create(buf: ByteBuf): FileHeader = {
+ val length = buf.readInt
+ val idLength = buf.readInt
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buf.readByte().asInstanceOf[Char]
+ }
+ val blockId = idBuilder.toString()
+ new FileHeader(length, blockId)
+ }
+
+
+  def main(args: Array[String]) {
+    val header = new FileHeader(25, "block_0")
+    val buf = header.buffer
+    val newHeader = FileHeader.create(buf)
+    println("id=" + newHeader.blockId + ", size=" + newHeader.fileLen)
+  }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
new file mode 100644
index 0000000000..9493ccffd9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.util.concurrent.Executors
+
+import io.netty.buffer.ByteBuf
+import io.netty.channel.ChannelHandlerContext
+import io.netty.util.CharsetUtil
+
+import org.apache.spark.Logging
+import org.apache.spark.network.ConnectionManagerId
+
+import scala.collection.JavaConverters._
+
+
+private[spark] class ShuffleCopier extends Logging {
+
+ def getBlock(host: String, port: Int, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
+ val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
+ val fc = new FileClient(handler, connectTimeout)
+
+ try {
+ fc.init()
+ fc.connect(host, port)
+ fc.sendRequest(blockId)
+ fc.waitForClose()
+ fc.close()
+ } catch {
+ // Handle any socket-related exceptions in FileClient
+ case e: Exception => {
+ logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
+ handler.handleError(blockId)
+ }
+ }
+ }
+
+ def getBlock(cmId: ConnectionManagerId, blockId: String,
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
+ }
+
+ def getBlocks(cmId: ConnectionManagerId,
+ blocks: Seq[(String, Long)],
+ resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+
+ for ((blockId, size) <- blocks) {
+ getBlock(cmId, blockId, resultCollectCallback)
+ }
+ }
+}
+
+
+private[spark] object ShuffleCopier extends Logging {
+
+ private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ extends FileClientHandler with Logging {
+
+ override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
+ }
+
+ override def handleError(blockId: String) {
+ if (!isComplete) {
+ resultCollectCallBack(blockId, -1, null)
+ }
+ }
+ }
+
+ def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ if (size != -1) {
+ logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
+ }
+ }
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
+ System.exit(1)
+ }
+ val host = args(0)
+ val port = args(1).toInt
+ val file = args(2)
+ val threads = if (args.length > 3) args(3).toInt else 10
+
+ val copiers = Executors.newFixedThreadPool(80)
+ val tasks = (for (i <- Range(0, threads)) yield {
+ Executors.callable(new Runnable() {
+ def run() {
+ val copier = new ShuffleCopier()
+ copier.getBlock(host, port, file, echoResultCollectCallBack)
+ }
+ })
+ }).asJava
+ copiers.invokeAll(tasks)
+ copiers.shutdown()
+ System.exit(0)
+ }
+}
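
Note: the connect timeout read in getBlock above is a plain Java system property with a 60 second
default, so a caller can raise it before starting a copy. A minimal sketch (the value chosen here
is illustrative):

    // Allow slower shuffle servers up to two minutes to accept the connection
    System.setProperty("spark.shuffle.netty.connect.timeout", "120000")
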
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
new file mode 100644
index 0000000000..537f225469
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.io.File
+
+import org.apache.spark.Logging
+
+
+private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
+
+ val server = new FileServer(pResolver, portIn)
+ server.start()
+
+ def stop() {
+ server.stop()
+ }
+
+ def port: Int = server.getPort()
+}
+
+
+/**
+ * An application for testing the shuffle sender as a standalone program.
+ */
+private[spark] object ShuffleSender {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
+ System.exit(1)
+ }
+
+ val port = args(0).toInt
+ val subDirsPerLocalDir = args(1).toInt
+ val localDirs = args.drop(2).map(new File(_))
+
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ throw new Exception("Block " + blockId + " is not a shuffle block")
+ }
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ val file = new File(subDir, blockId)
+ return file.getAbsolutePath
+ }
+ }
+ val sender = new ShuffleSender(port, pResolver)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
new file mode 100644
index 0000000000..1126480689
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to Spark, while
+ * [[org.apache.spark.RDD]] is the data type representing a distributed collection, and provides most
+ * parallel operations.
+ *
+ * In addition, [[org.apache.spark.PairRDDFunctions]] contains operations available only on RDDs of key-value
+ * pairs, such as `groupByKey` and `join`; [[org.apache.spark.DoubleRDDFunctions]] contains operations
+ * available only on RDDs of Doubles; and [[org.apache.spark.SequenceFileRDDFunctions]] contains operations
+ * available on RDDs that can be saved as SequenceFiles. These operations are automatically
+ * available on any RDD of the right type (e.g. RDD[(Int, Int)]) through implicit conversions when
+ * you `import org.apache.spark.SparkContext._`.
+ */
+package object spark {
+ // For package docs only
+}
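
To make the implicit conversions described in this package doc concrete, a minimal sketch
(assuming a local master; the object and variable names are illustrative):

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._   // brings the implicit RDD conversions into scope

    object ImplicitsExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "ImplicitsExample")
        val pairs = sc.parallelize(Seq((1, 2), (1, 3), (2, 4)))   // RDD[(Int, Int)]
        // groupByKey comes from PairRDDFunctions via the implicit conversion
        pairs.groupByKey().collect().foreach(println)
        sc.stop()
      }
    }
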
diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
new file mode 100644
index 0000000000..c5d51bee50
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import org.apache.spark._
+import org.apache.spark.scheduler.JobListener
+
+/**
+ * A JobListener for an approximate single-result action, such as count() or non-parallel reduce().
+ * This listener waits up to timeout milliseconds and will return a partial answer even if the
+ * complete answer is not available by then.
+ *
+ * This class assumes that the action is performed on an entire RDD[T] via a function that computes
+ * a result of type U for each partition, and that the action returns a partial or complete result
+ * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedDouble).
+ */
+private[spark] class ApproximateActionListener[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ timeout: Long)
+ extends JobListener {
+
+ val startTime = System.currentTimeMillis()
+ val totalTasks = rdd.partitions.size
+ var finishedTasks = 0
+ var failure: Option[Exception] = None // Set if the job has failed (permanently)
+ var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult
+
+ override def taskSucceeded(index: Int, result: Any) {
+ synchronized {
+ evaluator.merge(index, result.asInstanceOf[U])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ // If we had already returned a PartialResult, set its final value
+ resultObject.foreach(r => r.setFinalValue(evaluator.currentResult()))
+ // Notify any waiting thread that may have called awaitResult
+ this.notifyAll()
+ }
+ }
+ }
+
+ override def jobFailed(exception: Exception) {
+ synchronized {
+ failure = Some(exception)
+ this.notifyAll()
+ }
+ }
+
+ /**
+ * Waits for up to timeout milliseconds since the listener was created and then returns a
+ * PartialResult with the result so far. This may be complete if the whole job is done.
+ */
+ def awaitResult(): PartialResult[R] = synchronized {
+ val finishTime = startTime + timeout
+ while (true) {
+ val time = System.currentTimeMillis()
+ if (failure != None) {
+ throw failure.get
+ } else if (finishedTasks == totalTasks) {
+ return new PartialResult(evaluator.currentResult(), true)
+ } else if (time >= finishTime) {
+ resultObject = Some(new PartialResult(evaluator.currentResult(), false))
+ return resultObject.get
+ } else {
+ this.wait(finishTime - time)
+ }
+ }
+ // Should never be reached, but required to keep the compiler happy
+ return null
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala
new file mode 100644
index 0000000000..9c2859c8b9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/ApproximateEvaluator.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+/**
+ * An object that computes a function incrementally by merging in results of type U from multiple
+ * tasks. Allows partial evaluation at any point by calling currentResult().
+ */
+private[spark] trait ApproximateEvaluator[U, R] {
+ def merge(outputId: Int, taskResult: U): Unit
+ def currentResult(): R
+}
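
As a toy illustration of this contract, a hypothetical evaluator that tracks a running maximum
(not part of Spark; shown only to make merge/currentResult concrete):

    private[spark] class MaxEvaluator(totalOutputs: Int)
      extends ApproximateEvaluator[Double, Double] {

      var outputsMerged = 0
      var max = Double.NegativeInfinity

      override def merge(outputId: Int, taskResult: Double) {
        outputsMerged += 1
        max = math.max(max, taskResult)
      }

      // The "approximate" result is simply the maximum seen so far
      override def currentResult(): Double = max
    }
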
diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
new file mode 100644
index 0000000000..5f4450859c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+/**
+ * A Double with error bars on it.
+ */
+class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) {
+ override def toString(): String = "[%.3f, %.3f]".format(low, high)
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
new file mode 100644
index 0000000000..3155dfe165
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * An ApproximateEvaluator for counts.
+ *
+ * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might
+ * be best to make this a special case of GroupedCountEvaluator with one group.
+ */
+private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[Long, BoundedDouble] {
+
+ var outputsMerged = 0
+ var sum: Long = 0
+
+ override def merge(outputId: Int, taskResult: Long) {
+ outputsMerged += 1
+ sum += taskResult
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(sum, 1.0, sum, sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
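
A worked sketch of the extrapolation above, assuming two of four partitions have reported
(the numbers are illustrative):

    val evaluator = new CountEvaluator(totalOutputs = 4, confidence = 0.95)
    evaluator.merge(0, 1000L)   // partition 0 counted 1000 elements
    evaluator.merge(1, 980L)    // partition 1 counted 980 elements
    // p = 0.5, so the estimated total is (1980 + 1 - 0.5) / 0.5 = 3961, with a
    // 95% interval of roughly [3838, 4084]
    println(evaluator.currentResult())
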
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
new file mode 100644
index 0000000000..e519e3a548
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import cern.jet.stat.Probability
+
+import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
+
+/**
+ * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval.
+ */
+private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new OLMap[T] // Sum of counts for each key
+
+ override def merge(outputId: Int, taskResult: OLMap[T]) {
+ outputsMerged += 1
+ val iter = taskResult.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue)
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue()
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.object2LongEntrySet.fastIterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getLongValue
+ val mean = (sum + 1 - p) / p
+ val variance = (sum + 1) * (1 - p) / (p * p)
+ val stdev = math.sqrt(variance)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala
new file mode 100644
index 0000000000..cf8a5680b6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedMeanEvaluator.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import org.apache.spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval.
+ */
+private[spark] class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new JHashMap[T, StatCounter] // StatCounter for each key
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val mean = entry.getValue.mean
+ result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ result(entry.getKey) = new BoundedDouble(mean, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala
new file mode 100644
index 0000000000..8225a5d933
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedSumEvaluator.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import java.util.{HashMap => JHashMap}
+import java.util.{Map => JMap}
+
+import scala.collection.mutable.HashMap
+import scala.collection.Map
+import scala.collection.JavaConversions.mapAsScalaMap
+
+import org.apache.spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval.
+ */
+private[spark] class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] {
+
+ var outputsMerged = 0
+ var sums = new JHashMap[T, StatCounter] // StatCounter for each key
+
+ override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) {
+ outputsMerged += 1
+ val iter = taskResult.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val old = sums.get(entry.getKey)
+ if (old != null) {
+ old.merge(entry.getValue)
+ } else {
+ sums.put(entry.getKey, entry.getValue)
+ }
+ }
+ }
+
+ override def currentResult(): Map[T, BoundedDouble] = {
+ if (outputsMerged == totalOutputs) {
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val sum = entry.getValue.sum
+ result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum)
+ }
+ result
+ } else if (outputsMerged == 0) {
+ new HashMap[T, BoundedDouble]
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val studentTCacher = new StudentTCacher(confidence)
+ val result = new JHashMap[T, BoundedDouble](sums.size)
+ val iter = sums.entrySet.iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ val counter = entry.getValue
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = studentTCacher.get(counter.count)
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ result
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
new file mode 100644
index 0000000000..d24959cba8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import cern.jet.stat.Probability
+
+import org.apache.spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for means.
+ */
+private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val mean = counter.mean
+ val stdev = math.sqrt(counter.sampleVariance / counter.count)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = mean - confFactor * stdev
+ val high = mean + confFactor * stdev
+ new BoundedDouble(mean, confidence, low, high)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
new file mode 100644
index 0000000000..5ce49b8100
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+class PartialResult[R](initialVal: R, isFinal: Boolean) {
+ private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None
+ private var failure: Option[Exception] = None
+ private var completionHandler: Option[R => Unit] = None
+ private var failureHandler: Option[Exception => Unit] = None
+
+ def initialValue: R = initialVal
+
+ def isInitialValueFinal: Boolean = isFinal
+
+ /**
+ * Blocking method to wait for and return the final value.
+ */
+ def getFinalValue(): R = synchronized {
+ while (finalValue == None && failure == None) {
+ this.wait()
+ }
+ if (finalValue != None) {
+ return finalValue.get
+ } else {
+ throw failure.get
+ }
+ }
+
+ /**
+ * Set a handler to be called when this PartialResult completes. Only one completion handler
+ * is supported per PartialResult.
+ */
+ def onComplete(handler: R => Unit): PartialResult[R] = synchronized {
+ if (completionHandler != None) {
+ throw new UnsupportedOperationException("onComplete cannot be called twice")
+ }
+ completionHandler = Some(handler)
+ if (finalValue != None) {
+ // We already have a final value, so let's call the handler
+ handler(finalValue.get)
+ }
+ return this
+ }
+
+ /**
+ * Set a handler to be called if this PartialResult's job fails. Only one failure handler
+ * is supported per PartialResult.
+ */
+ def onFail(handler: Exception => Unit) {
+ synchronized {
+ if (failureHandler != None) {
+ throw new UnsupportedOperationException("onFail cannot be called twice")
+ }
+ failureHandler = Some(handler)
+ if (failure != None) {
+ // We already have a failure, so let's call the handler
+ handler(failure.get)
+ }
+ }
+ }
+
+ /**
+ * Transform this PartialResult into a PartialResult of type T.
+ */
+ def map[T](f: R => T) : PartialResult[T] = {
+ new PartialResult[T](f(initialVal), isFinal) {
+ override def getFinalValue() : T = synchronized {
+ f(PartialResult.this.getFinalValue())
+ }
+ override def onComplete(handler: T => Unit): PartialResult[T] = synchronized {
+ PartialResult.this.onComplete(handler.compose(f)).map(f)
+ }
+ override def onFail(handler: Exception => Unit) {
+ synchronized {
+ PartialResult.this.onFail(handler)
+ }
+ }
+ override def toString : String = synchronized {
+ PartialResult.this.getFinalValueInternal() match {
+ case Some(value) => "(final: " + f(value) + ")"
+ case None => "(partial: " + initialValue + ")"
+ }
+ }
+ def getFinalValueInternal() = PartialResult.this.getFinalValueInternal().map(f)
+ }
+ }
+
+ private[spark] def setFinalValue(value: R) {
+ synchronized {
+ if (finalValue != None) {
+ throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
+ }
+ finalValue = Some(value)
+ // Call the completion handler if it was set
+ completionHandler.foreach(h => h(value))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+ }
+
+ private def getFinalValueInternal() = finalValue
+
+ private[spark] def setFailure(exception: Exception) {
+ synchronized {
+ if (failure != None) {
+ throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
+ }
+ failure = Some(exception)
+ // Call the failure handler if it was set
+ failureHandler.foreach(h => h(exception))
+ // Notify any threads that may be calling getFinalValue()
+ this.notifyAll()
+ }
+ }
+
+ override def toString: String = synchronized {
+ finalValue match {
+ case Some(value) => "(final: " + value + ")"
+ case None => "(partial: " + initialValue + ")"
+ }
+ }
+}
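
A hedged usage sketch, assuming the countApprox action on RDD (which returns a
PartialResult[BoundedDouble]); the master, data size, and timeout below are illustrative:

    val sc = new SparkContext("local", "PartialResultExample")
    val data = sc.parallelize(1 to 1000000, 100)
    // Wait at most 500 ms, then take whatever estimate is available
    val partial = data.countApprox(500, 0.95)
    println("estimate after timeout: " + partial.initialValue)
    partial.onComplete(bd => println("final interval: " + bd))
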
diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
new file mode 100644
index 0000000000..92915ee66d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * A utility class for caching Student's T distribution values for a given confidence level
+ * and various sample sizes. This is used by the grouped evaluators (GroupedMeanEvaluator and
+ * GroupedSumEvaluator) to efficiently calculate confidence intervals for many keys.
+ */
+private[spark] class StudentTCacher(confidence: Double) {
+ val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
+ val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
+
+ def get(sampleSize: Long): Double = {
+ if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
+ normalApprox
+ } else {
+ val size = sampleSize.toInt
+ if (cache(size) < 0) {
+ cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ }
+ cache(size)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
new file mode 100644
index 0000000000..a74f800944
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.partial
+
+import cern.jet.stat.Probability
+
+import org.apache.spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
+ * together, then uses the formula for the variance of two independent random variables to get
+ * a variance for the result and compute a confidence interval.
+ */
+private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
+ extends ApproximateEvaluator[StatCounter, BoundedDouble] {
+
+ var outputsMerged = 0
+ var counter = new StatCounter
+
+ override def merge(outputId: Int, taskResult: StatCounter) {
+ outputsMerged += 1
+ counter.merge(taskResult)
+ }
+
+ override def currentResult(): BoundedDouble = {
+ if (outputsMerged == totalOutputs) {
+ new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum)
+ } else if (outputsMerged == 0) {
+ new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)
+ } else {
+ val p = outputsMerged.toDouble / totalOutputs
+ val meanEstimate = counter.mean
+ val meanVar = counter.sampleVariance / counter.count
+ val countEstimate = (counter.count + 1 - p) / p
+ val countVar = (counter.count + 1) * (1 - p) / (p * p)
+ val sumEstimate = meanEstimate * countEstimate
+ val sumVar = (meanEstimate * meanEstimate * countVar) +
+ (countEstimate * countEstimate * meanVar) +
+ (meanVar * countVar)
+ val sumStdev = math.sqrt(sumVar)
+ val confFactor = {
+ if (counter.count > 100) {
+ Probability.normalInverse(1 - (1 - confidence) / 2)
+ } else {
+ Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ }
+ }
+ val low = sumEstimate - confFactor * sumStdev
+ val high = sumEstimate + confFactor * sumStdev
+ new BoundedDouble(sumEstimate, confidence, low, high)
+ }
+ }
+}
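
The sumVar expression above follows the standard identity for the variance of a product of two
independent random variables X and Y: Var(XY) = E[X]^2 Var(Y) + E[Y]^2 Var(X) + Var(X) Var(Y).
Here X is the per-element mean estimate and Y is the extrapolated count, so the three terms map
directly onto meanEstimate^2 * countVar, countEstimate^2 * meanVar, and meanVar * countVar.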
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
new file mode 100644
index 0000000000..4bb01efa86
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+import org.apache.spark.storage.BlockManager
+
+private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+ val index = idx
+}
+
+private[spark]
+class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
+ extends RDD[T](sc, Nil) {
+
+ @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+
+ override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
+ new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
+ }).toArray
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val blockManager = SparkEnv.get.blockManager
+ val blockId = split.asInstanceOf[BlockRDDPartition].blockId
+ blockManager.get(blockId) match {
+ case Some(block) => block.asInstanceOf[Iterator[T]]
+ case None =>
+ throw new Exception("Could not compute split, block " + blockId + " not found")
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ locations_(split.asInstanceOf[BlockRDDPartition].blockId)
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
new file mode 100644
index 0000000000..9b0c882481
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{ObjectOutputStream, IOException}
+import org.apache.spark._
+
+
+private[spark]
+class CartesianPartition(
+ idx: Int,
+ @transient rdd1: RDD[_],
+ @transient rdd2: RDD[_],
+ s1Index: Int,
+ s2Index: Int
+ ) extends Partition {
+ var s1 = rdd1.partitions(s1Index)
+ var s2 = rdd2.partitions(s2Index)
+ override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ s1 = rdd1.partitions(s1Index)
+ s2 = rdd2.partitions(s2Index)
+ oos.defaultWriteObject()
+ }
+}
+
+private[spark]
+class CartesianRDD[T: ClassManifest, U:ClassManifest](
+ sc: SparkContext,
+ var rdd1 : RDD[T],
+ var rdd2 : RDD[U])
+ extends RDD[Pair[T, U]](sc, Nil)
+ with Serializable {
+
+ val numPartitionsInRdd2 = rdd2.partitions.size
+
+ override def getPartitions: Array[Partition] = {
+ // create the cross product split
+ val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size)
+ for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) {
+ val idx = s1.index * numPartitionsInRdd2 + s2.index
+ array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index)
+ }
+ array
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
+ (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
+ }
+
+ override def compute(split: Partition, context: TaskContext) = {
+ val currSplit = split.asInstanceOf[CartesianPartition]
+ for (x <- rdd1.iterator(currSplit.s1, context);
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = List(
+ new NarrowDependency(rdd1) {
+ def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2)
+ },
+ new NarrowDependency(rdd2) {
+ def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2)
+ }
+ )
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
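
A hedged usage sketch, assuming the cartesian method on RDD (which builds a CartesianRDD) and an
existing SparkContext sc:

    val letters = sc.parallelize(Seq("a", "b"))
    val numbers = sc.parallelize(Seq(1, 2, 3))
    // Yields all 2 x 3 = 6 pairs, e.g. ("a", 1), ("a", 2), ...
    val pairs = letters.cartesian(numbers)   // RDD[(String, Int)]
    pairs.collect().foreach(println)
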
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..3311757189
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
+
+/**
+ * This RDD represents an RDD checkpoint file (similar to HadoopRDD).
+ */
+private[spark]
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
+ extends RDD[T](sc, Nil) {
+
+ @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+
+ override def getPartitions: Array[Partition] = {
+ val cpath = new Path(checkpointPath)
+ val numPartitions =
+ // listStatus can throw exception if path does not exist.
+ if (fs.exists(cpath)) {
+ val dirContents = fs.listStatus(cpath)
+ val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
+ val numPart = partitionFiles.size
+ if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+ ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
+ throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
+ }
+ numPart
+ } else 0
+
+ Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
+ }
+
+ checkpointData = Some(new RDDCheckpointData[T](this))
+ checkpointData.get.cpFile = Some(checkpointPath)
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
+ val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+ locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
+ CheckpointRDD.readFromFile(file, context)
+ }
+
+ override def checkpoint() {
+ // Do nothing. CheckpointRDD should not be checkpointed.
+ }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+ def splitIdToFile(splitId: Int): String = {
+ "part-%05d".format(splitId)
+ }
+
+ def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
+ val env = SparkEnv.get
+ val outputDir = new Path(path)
+ val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
+
+ val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputPath = new Path(outputDir, finalOutputName)
+ val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+
+ if (fs.exists(tempOutputPath)) {
+ throw new IOException("Checkpoint failed: temporary path " +
+ tempOutputPath + " already exists")
+ }
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+ val fileOutputStream = if (blockSize < 0) {
+ fs.create(tempOutputPath, false, bufferSize)
+ } else {
+ // This is mainly for testing purposes
+ fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+ }
+ val serializer = env.serializer.newInstance()
+ val serializeStream = serializer.serializeStream(fileOutputStream)
+ serializeStream.writeAll(iterator)
+ serializeStream.close()
+
+ if (!fs.rename(tempOutputPath, finalOutputPath)) {
+ if (!fs.exists(finalOutputPath)) {
+ logInfo("Deleting tempOutputPath " + tempOutputPath)
+ fs.delete(tempOutputPath, false)
+ throw new IOException("Checkpoint failed: failed to save output of task: "
+ + ctx.attemptId + " and final output path does not exist")
+ } else {
+ // Some other copy of this task must've finished before us and renamed it
+ logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+ fs.delete(tempOutputPath, false)
+ }
+ }
+ }
+
+ def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
+ val env = SparkEnv.get
+ val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val fileInputStream = fs.open(path, bufferSize)
+ val serializer = env.serializer.newInstance()
+ val deserializeStream = serializer.deserializeStream(fileInputStream)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => deserializeStream.close())
+
+ deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+ }
+
+ // Test whether CheckpointRDD generates the expected number of partitions despite
+ // each split file having multiple blocks. This needs to be run on a
+ // cluster (mesos or standalone) using HDFS.
+ def main(args: Array[String]) {
+ import org.apache.spark._
+
+ val Array(cluster, hdfsPath) = args
+ val env = SparkEnv.get
+ val sc = new SparkContext(cluster, "CheckpointRDD Test")
+ val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+ val path = new Path(hdfsPath, "temp")
+ val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
+ val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+ assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
+ assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
+ fs.delete(path, true)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
new file mode 100644
index 0000000000..dcc35e8d0e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.{ObjectOutputStream, IOException}
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.JavaConversions
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Partition, Partitioner, RDD, SparkEnv, TaskContext}
+import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+
+
+private[spark] sealed trait CoGroupSplitDep extends Serializable
+
+private[spark] case class NarrowCoGroupSplitDep(
+ rdd: RDD[_],
+ splitIndex: Int,
+ var split: Partition
+ ) extends CoGroupSplitDep {
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.partitions(splitIndex)
+ oos.defaultWriteObject()
+ }
+}
+
+private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+
+private[spark]
+class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+ extends Partition with Serializable {
+ override val index: Int = idx
+ override def hashCode(): Int = idx
+}
+
+
+/**
+ * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
+ * tuple with the list of values for that key.
+ *
+ * @param rdds parent RDDs.
+ * @param part partitioner used to partition the shuffle output.
+ */
+class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
+ extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
+
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): CoGroupedRDD[K] = {
+ serializerClass = cls
+ this
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ rdds.map { rdd: RDD[_ <: Product2[K, _]] =>
+ if (rdd.partitioner == Some(part)) {
+ logDebug("Adding one-to-one dependency with " + rdd)
+ new OneToOneDependency(rdd)
+ } else {
+ logDebug("Adding shuffle dependency with " + rdd)
+ new ShuffleDependency[Any, Any](rdd, part, serializerClass)
+ }
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](part.numPartitions)
+ for (i <- 0 until array.size) {
+ // Each CoGroupPartition will have a dependency per contributing RDD
+ array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
+ // Assume each RDD contributed a single dependency, and get it
+ dependencies(j) match {
+ case s: ShuffleDependency[_, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleId)
+ case _ =>
+ new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+ }
+ }.toArray)
+ }
+ array
+ }
+
+ override val partitioner = Some(part)
+
+ override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
+ val split = s.asInstanceOf[CoGroupPartition]
+ val numRdds = split.deps.size
+ // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
+ val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+
+ def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
+ map.put(k, seq)
+ seq
+ }
+ }
+
+ val ser = SparkEnv.get.serializerManager.get(serializerClass)
+ for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ // Read them from the parent
+ rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]].foreach { kv =>
+ getSeq(kv._1)(depNum) += kv._2
+ }
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ // Read map outputs of shuffle
+ val fetcher = SparkEnv.get.shuffleFetcher
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ kv => getSeq(kv._1)(depNum) += kv._2
+ }
+ }
+ }
+ JavaConversions.mapAsScalaMap(map).iterator
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+}
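
A hedged usage sketch, assuming the cogroup method exposed through PairRDDFunctions (available
after `import org.apache.spark.SparkContext._`) and an existing SparkContext sc:

    val scores = sc.parallelize(Seq(("alice", 1), ("bob", 2)))
    val ages   = sc.parallelize(Seq(("alice", 30), ("carol", 25)))
    // For each key, collects the values from both RDDs side by side
    scores.cogroup(ages).collect().foreach(println)
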
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
new file mode 100644
index 0000000000..c5de6362a9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark._
+import java.io.{ObjectOutputStream, IOException}
+import scala.collection.mutable
+import scala.Some
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A partition of a coalesced RDD that essentially keeps track of its parent partitions.
+ * @param index the index of this coalesced partition
+ * @param rdd the parent RDD whose partitions are being coalesced
+ * @param parentsIndices list of indices in the parent that have been coalesced into this partition
+ * @param preferredLocation the preferred location for this partition
+ */
+case class CoalescedRDDPartition(
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int],
+ @transient preferredLocation: String = ""
+ ) extends Partition {
+ var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent partition at the time of task serialization
+ parents = parentsIndices.map(rdd.partitions(_))
+ oos.defaultWriteObject()
+ }
+
+ /**
+ * Computes how many of the parents partitions have getPreferredLocation
+ * as one of their preferredLocations
+ * @return locality of this coalesced partition between 0 and 1
+ */
+ def localFraction: Double = {
+ val loc = parents.count(p =>
+ rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation))
+
+ if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble)
+ }
+}
+
+/**
+ * Represents a coalesced RDD that has fewer partitions than its parent RDD.
+ * This class uses the PartitionCoalescer class to find a good partitioning of the parent RDD
+ * so that each new partition has roughly the same number of parent partitions and so that
+ * the preferred location of each new partition overlaps with as many preferred locations of its
+ * parent partitions as possible.
+ * @param prev RDD to be coalesced
+ * @param maxPartitions number of desired partitions in the coalesced RDD
+ * @param balanceSlack used to trade off balance and locality. 1.0 is all locality, 0 is all balance.
+ */
+class CoalescedRDD[T: ClassManifest](
+ @transient var prev: RDD[T],
+ maxPartitions: Int,
+ balanceSlack: Double = 0.10)
+ extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
+
+ override def getPartitions: Array[Partition] = {
+ val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack)
+
+ pc.run().zipWithIndex.map {
+ case (pg, i) =>
+ val ids = pg.arr.map(_.index).toArray
+ new CoalescedRDDPartition(i, prev, ids, pg.prefLoc)
+ }
+ }
+
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ partition.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentPartition =>
+ firstParent[T].iterator(parentPartition, context)
+ }
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ Seq(new NarrowDependency(prev) {
+ def getParents(id: Int): Seq[Int] =
+ partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices
+ })
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
+
+ /**
+ * Returns the preferred machine for the partition. If the partition is of type
+ * CoalescedRDDPartition, then the preferred machine will be one that most parent splits prefer.
+ * @param partition the partition to find a preferred machine for
+ * @return the machine most preferred by the partition
+ */
+ override def getPreferredLocations(partition: Partition): Seq[String] = {
+ List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ }
+}
+
+/**
+ * Coalesce the partitions of a parent RDD (`prev`) into fewer partitions, so that each partition of
+ * this RDD computes one or more of the parent ones. It will produce exactly `maxPartitions` if the
+ * parent had more than maxPartitions, or fewer if the parent had fewer.
+ *
+ * This transformation is useful when an RDD with many partitions gets filtered into a smaller one,
+ * or to avoid having a large number of small tasks when processing a directory with many files.
+ *
+ * If there is no locality information (no preferredLocations) in the parent, then the coalescing
+ * is very simple: group parent partitions that are adjacent in the array into chunks.
+ * If there is locality information, it proceeds to pack them with the following four goals:
+ *
+ * (1) Balance the groups so they roughly have the same number of parent partitions
+ * (2) Achieve locality per partition, i.e. find one machine which most parent partitions prefer
+ * (3) Be efficient, i.e. O(n) algorithm for n parent partitions (problem is likely NP-hard)
+ * (4) Balance preferred machines, i.e. avoid as much as possible picking the same preferred machine
+ *
+ * Furthermore, it is assumed that the parent RDD may have many partitions, e.g. 100 000.
+ * We assume the final number of desired partitions is small, e.g. less than 1000.
+ *
+ * The algorithm tries to assign unique preferred machines to each partition. If the number of
+ * desired partitions is greater than the number of preferred machines (can happen), it needs to
+ * start picking duplicate preferred machines. This is determined using coupon collector estimation
+ * (2n log(n)). The load balancing is done using power-of-two-choices randomized balls-into-bins,
+ * with one twist:
+ * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two
+ * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions
+ * according to locality. (contact alig for questions)
+ *
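+ * An illustrative sketch of how the coalescer is driven (this mirrors CoalescedRDD.getPartitions
+ * above; `parentRdd` stands for any parent RDD):
+ * {{{
+ *   val pc = new PartitionCoalescer(maxPartitions = 10, prev = parentRdd, balanceSlack = 0.10)
+ *   val groups: Array[PartitionGroup] = pc.run()   // load-balanced, locality-aware groups
+ * }}}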
+ */
+
+private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+
+ def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
+ def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
+ if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get)
+
+ val rnd = new scala.util.Random(7919) // keep this class deterministic
+
+ // each element of groupArr represents one coalesced partition
+ val groupArr = ArrayBuffer[PartitionGroup]()
+
+ // hash used to check whether some machine is already in groupArr
+ val groupHash = mutable.Map[String, ArrayBuffer[PartitionGroup]]()
+
+ // hash used for the first maxPartitions (to avoid duplicates)
+ val initialHash = mutable.Set[Partition]()
+
+ // determines the tradeoff between load-balancing the partitions sizes and their locality
+ // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality
+ val slack = (balanceSlack * prev.partitions.size).toInt
+
+ var noLocality = true // true iff no preferredLocations exist for the parent RDD
+
+ // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+ def currPrefLocs(part: Partition): Seq[String] = {
+ prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host)
+ }
+
+ // this class just keeps iterating and rotating infinitely over the partitions of the RDD
+ // next() returns the next preferred machine that a partition is replicated on
+ // the rotator first goes through the first replica copy of each partition, then second, third
+ // the iterator's return type is a tuple: (replicaString, partition)
+ class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] {
+
+ var it: Iterator[(String, Partition)] = resetIterator()
+
+ override val isEmpty = !it.hasNext
+
+ // initializes/resets to start iterating from the beginning
+ def resetIterator() = {
+ val iterators = (0 to 2).map( x =>
+ prev.partitions.iterator.flatMap(p => {
+ if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None
+ } )
+ )
+ iterators.reduceLeft((x, y) => x ++ y)
+ }
+
+ // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD
+ def hasNext(): Boolean = { !isEmpty }
+
+ // return the next preferredLocation of some partition of the RDD
+ def next(): (String, Partition) = {
+ if (it.hasNext)
+ it.next()
+ else {
+ it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning
+ it.next()
+ }
+ }
+ }
+
+ /**
+ * Sorts and gets the least element of the list associated with key in groupHash.
+ * The returned PartitionGroup is the least loaded of all groups that represent the machine "key".
+ * @param key string representing the preferred machine (host) whose groups are searched
+ * @return Option of the PartitionGroup that has the fewest elements for this key
+ */
+ def getLeastGroupHash(key: String): Option[PartitionGroup] = {
+ groupHash.get(key).map(_.sortWith(compare).head)
+ }
+
+ def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = {
+ if (!initialHash.contains(part)) {
+ pgroup.arr += part // assign this element to the group
+ initialHash += part // needed to avoid assigning partitions to multiple buckets
+ true
+ } else { false }
+ }
+
+ /**
+ * Initializes targetLen partition groups and assigns each a preferredLocation.
+ * This uses the coupon collector bound to estimate how many preferredLocations it must rotate
+ * through until it has seen most of the preferred locations (2 * n log(n)).
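+ *
+ * For example, with targetLen = 100 this bound is roughly
+ * 2 * (100 * ln(100) + 100) ~= 1122 rotations of the location iterator.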
+ * @param targetLen the number of partition groups to create
+ */
+ def setupGroups(targetLen: Int) {
+ val rotIt = new LocationIterator(prev)
+
+ // deal with empty case, just create targetLen partition groups with no preferred location
+ if (!rotIt.hasNext()) {
+ (1 to targetLen).foreach(x => groupArr += PartitionGroup())
+ return
+ }
+
+ noLocality = false
+
+ // number of iterations needed to be certain that we've seen most preferred locations
+ val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt
+ var numCreated = 0
+ var tries = 0
+
+ // rotate through until either targetLen unique/distinct preferred locations have been created
+ // OR we've rotated expectedCoupons2 times, in which case we have likely seen all preferred locations,
+ // i.e. likely targetLen >> number of preferred locations (more buckets than there are machines)
+ while (numCreated < targetLen && tries < expectedCoupons2) {
+ tries += 1
+ val (nxt_replica, nxt_part) = rotIt.next()
+ if (!groupHash.contains(nxt_replica)) {
+ val pgroup = PartitionGroup(nxt_replica)
+ groupArr += pgroup
+ addPartToPGroup(nxt_part, pgroup)
+ groupHash += (nxt_replica -> (ArrayBuffer(pgroup))) // list in case we have multiple
+ numCreated += 1
+ }
+ }
+
+ while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates
+ var (nxt_replica, nxt_part) = rotIt.next()
+ val pgroup = PartitionGroup(nxt_replica)
+ groupArr += pgroup
+ groupHash.get(nxt_replica).get += pgroup
+ var tries = 0
+ while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part
+ nxt_part = rotIt.next()._2
+ tries += 1
+ }
+ numCreated += 1
+ }
+
+ }
+
+ /**
+ * Takes a parent RDD partition and decides which of the partition groups to put it in.
+ * Takes locality into account, but also uses power of 2 choices to load balance.
+ * It strikes a balance between the two using the balanceSlack variable.
+ * @param p partition (ball to be thrown)
+ * @return partition group (bin to be put in)
+ */
+ def pickBin(p: Partition): PartitionGroup = {
+ val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs
+ val prefPart = if (pref == Nil) None else pref.head
+
+ val r1 = rnd.nextInt(groupArr.size)
+ val r2 = rnd.nextInt(groupArr.size)
+ val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2)
+ if (prefPart == None) // if no preferred locations, just use basic power of two
+ return minPowerOfTwo
+
+ val prefPartActual = prefPart.get
+
+ if (minPowerOfTwo.size + slack <= prefPartActual.size) // more imbalance than the slack allows
+ return minPowerOfTwo // prefer balance over locality
+ else {
+ return prefPartActual // prefer locality over balance
+ }
+ }
+
+ def throwBalls() {
+ if (noLocality) { // no preferredLocations in parent RDD, no randomization needed
+ if (maxPartitions > groupArr.size) { // just return prev.partitions
+ for ((p,i) <- prev.partitions.zipWithIndex) {
+ groupArr(i).arr += p
+ }
+ } else { // no locality available, so simply split partitions based on their position in the array
+ for(i <- 0 until maxPartitions) {
+ val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt
+ val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt
+ (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) }
+ }
+ }
+ } else {
+ for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group
+ pickBin(p).arr += p
+ }
+ }
+ }
+
+ def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray
+
+ /**
+ * Runs the packing algorithm and returns an array of PartitionGroups that if possible are
+ * load balanced and grouped by locality
+ * @return array of partition groups
+ */
+ def run(): Array[PartitionGroup] = {
+ setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins)
+ throwBalls() // assign partitions (balls) to each group (bins)
+ getPartitions
+ }
+}
+
+private[spark] case class PartitionGroup(prefLoc: String = "") {
+ var arr = mutable.ArrayBuffer[Partition]()
+
+ def size = arr.size
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
new file mode 100644
index 0000000000..24ce4abbc4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
+
+
+/**
+ * An RDD that is empty, i.e. has no elements in it.
+ */
+class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+
+ override def getPartitions: Array[Partition] = Array.empty
+
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ throw new UnsupportedOperationException("empty RDD")
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
new file mode 100644
index 0000000000..4df8ceb58b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{OneToOneDependency, RDD, Partition, TaskContext}
+
+private[spark] class FilteredRDD[T: ClassManifest](
+ prev: RDD[T],
+ f: T => Boolean)
+ extends RDD[T](prev) {
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
+
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[T].iterator(split, context).filter(f)
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
new file mode 100644
index 0000000000..2bf7653af1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+
+private[spark]
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: T => TraversableOnce[U])
+ extends RDD[U](prev) {
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[T].iterator(split, context).flatMap(f)
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala
new file mode 100644
index 0000000000..e544720b05
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{TaskContext, Partition, RDD}
+
+
+private[spark]
+class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => TraversableOnce[U])
+ extends RDD[(K, U)](prev) {
+
+ override def getPartitions = firstParent[Product2[K, V]].partitions
+
+ override val partitioner = firstParent[Product2[K, V]].partitioner
+
+ override def compute(split: Partition, context: TaskContext) = {
+ firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) =>
+ f(v).map(x => (k, x))
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
new file mode 100644
index 0000000000..2ce94199f2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+ extends RDD[Array[T]](prev) {
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override def compute(split: Partition, context: TaskContext) =
+ Array(firstParent[T].iterator(split, context).toArray).iterator
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
new file mode 100644
index 0000000000..08e6154bb9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.EOFException
+import java.util.NoSuchElementException
+
+import org.apache.hadoop.io.LongWritable
+import org.apache.hadoop.io.NullWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.hadoop.mapred.InputFormat
+import org.apache.hadoop.mapred.InputSplit
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapred.RecordReader
+import org.apache.hadoop.mapred.Reporter
+import org.apache.hadoop.util.ReflectionUtils
+
+import org.apache.spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.util.NextIterator
+import org.apache.hadoop.conf.{Configuration, Configurable}
+
+
+/**
+ * A Spark split class that wraps around a Hadoop InputSplit.
+ */
+private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
+ extends Partition {
+
+ val inputSplit = new SerializableWritable[InputSplit](s)
+
+ override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
+
+ override val index: Int = idx
+}
+
+/**
+ * An RDD that reads a Hadoop dataset as specified by a JobConf (e.g. files in HDFS, the local file
+ * system, or S3; tables in HBase; etc.).
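+ *
+ * A minimal usage sketch (illustrative; assumes an existing SparkContext `sc` and an input path
+ * `hdfs:///data/input`; `SparkContext.hadoopRDD` and `SparkContext.textFile` construct this
+ * class internally):
+ * {{{
+ *   import org.apache.hadoop.io.{LongWritable, Text}
+ *   import org.apache.hadoop.mapred.{FileInputFormat, JobConf, TextInputFormat}
+ *   val jobConf = new JobConf()
+ *   FileInputFormat.setInputPaths(jobConf, "hdfs:///data/input")
+ *   val records = sc.hadoopRDD(jobConf, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], 4)
+ * }}}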
+ */
+class HadoopRDD[K, V](
+ sc: SparkContext,
+ @transient conf: JobConf,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ minSplits: Int)
+ extends RDD[(K, V)](sc, Nil) with Logging {
+
+ // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+
+ override def getPartitions: Array[Partition] = {
+ val env = SparkEnv.get
+ env.hadoop.addCredentials(conf)
+ val inputFormat = createInputFormat(conf)
+ if (inputFormat.isInstanceOf[Configurable]) {
+ inputFormat.asInstanceOf[Configurable].setConf(conf)
+ }
+ val inputSplits = inputFormat.getSplits(conf, minSplits)
+ val array = new Array[Partition](inputSplits.size)
+ for (i <- 0 until inputSplits.size) {
+ array(i) = new HadoopPartition(id, i, inputSplits(i))
+ }
+ array
+ }
+
+ def createInputFormat(conf: JobConf): InputFormat[K, V] = {
+ ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf)
+ .asInstanceOf[InputFormat[K, V]]
+ }
+
+ override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
+ logInfo("Input split: " + split.inputSplit)
+ var reader: RecordReader[K, V] = null
+
+ val conf = confBroadcast.value.value
+ val fmt = createInputFormat(conf)
+ if (fmt.isInstanceOf[Configurable]) {
+ fmt.asInstanceOf[Configurable].setConf(conf)
+ }
+ reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
+ override def getNext() = {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ (key, value)
+ }
+
+ override def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ // TODO: Filtering out "localhost" in case of file:// URLs
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
+ hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+ }
+
+ override def checkpoint() {
+ // Do nothing. Hadoop RDD should not be checkpointed.
+ }
+
+ def getConf: Configuration = confBroadcast.value.value
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
new file mode 100644
index 0000000000..3db460b3ce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.sql.{Connection, ResultSet}
+
+import org.apache.spark.{Logging, Partition, RDD, SparkContext, TaskContext}
+import org.apache.spark.util.NextIterator
+
+private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
+ override def index = idx
+}
+
+/**
+ * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JdbcRDDSuite.
+ *
+ * @param getConnection a function that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
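+ *
+ * A minimal usage sketch (illustrative; assumes an existing SparkContext `sc`, a reachable
+ * database URL, and a `books` table with a numeric `id` column):
+ * {{{
+ *   import java.sql.DriverManager
+ *   val rdd = new JdbcRDD(
+ *     sc,
+ *     () => DriverManager.getConnection("jdbc:mysql://host/db", "user", "password"),
+ *     "select title, author from books where ? <= id and id <= ?",
+ *     lowerBound = 1, upperBound = 1000, numPartitions = 10,
+ *     mapRow = rs => (rs.getString(1), rs.getString(2)))
+ * }}}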
+ */
+class JdbcRDD[T: ClassManifest](
+ sc: SparkContext,
+ getConnection: () => Connection,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
+ extends RDD[T](sc, Nil) with Logging {
+
+ override def getPartitions: Array[Partition] = {
+ // bounds are inclusive, hence the + 1 here and - 1 on end
+ val length = 1 + upperBound - lowerBound
+ (0 until numPartitions).map(i => {
+ val start = lowerBound + ((i * length) / numPartitions).toLong
+ val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
+ new JdbcPartition(i, start, end)
+ }).toArray
+ }
+
+ override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+ val part = thePart.asInstanceOf[JdbcPartition]
+ val conn = getConnection()
+ val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+
+ // setFetchSize(Integer.MIN_VALUE) is a MySQL-driver-specific way to force streaming results,
+ // rather than pulling the entire result set into memory.
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
+ if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
+ stmt.setFetchSize(Integer.MIN_VALUE)
+ logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
+ }
+
+ stmt.setLong(1, part.lower)
+ stmt.setLong(2, part.upper)
+ val rs = stmt.executeQuery()
+
+ override def getNext: T = {
+ if (rs.next()) {
+ mapRow(rs)
+ } else {
+ finished = true
+ null.asInstanceOf[T]
+ }
+ }
+
+ override def close() {
+ try {
+ if (null != rs && ! rs.isClosed()) rs.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ try {
+ if (null != stmt && ! stmt.isClosed()) stmt.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing statement", e)
+ }
+ try {
+ if (null != conn && ! conn.isClosed()) conn.close()
+ logInfo("closed connection")
+ } catch {
+ case e: Exception => logWarning("Exception closing connection", e)
+ }
+ }
+ }
+}
+
+object JdbcRDD {
+ def resultSetToObjectArray(rs: ResultSet) = {
+ Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
new file mode 100644
index 0000000000..13009d3e17
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+
+private[spark]
+class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: Iterator[T] => Iterator[U],
+ preservesPartitioning: Boolean = false)
+ extends RDD[U](prev) {
+
+ override val partitioner =
+ if (preservesPartitioning) firstParent[T].partitioner else None
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override def compute(split: Partition, context: TaskContext) =
+ f(firstParent[T].iterator(split, context))
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
new file mode 100644
index 0000000000..1683050b86
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+
+/**
+ * A variant of the MapPartitionsRDD that passes the partition index into the
+ * closure. This can be used to generate or collect partition-specific
+ * information such as the number of tuples in a partition.
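+ *
+ * A small illustrative sketch (assumes an existing SparkContext `sc`; RDD.mapPartitionsWithIndex
+ * builds this RDD):
+ * {{{
+ *   val rdd = sc.parallelize(1 to 100, 4)
+ *   // count the number of elements in each partition
+ *   val counts = rdd.mapPartitionsWithIndex((idx, iter) => Iterator((idx, iter.size)))
+ * }}}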
+ */
+private[spark]
+class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean
+ ) extends RDD[U](prev) {
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
+ override def compute(split: Partition, context: TaskContext) =
+ f(split.index, firstParent[T].iterator(split, context))
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
new file mode 100644
index 0000000000..26d4806edb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+private[spark]
+class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
+ extends RDD[U](prev) {
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override def compute(split: Partition, context: TaskContext) =
+ firstParent[T].iterator(split, context).map(f)
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala
new file mode 100644
index 0000000000..a405e9acdd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+
+import org.apache.spark.{TaskContext, Partition, RDD}
+
+private[spark]
+class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U)
+ extends RDD[(K, U)](prev) {
+
+ override def getPartitions = firstParent[Product2[K, U]].partitions
+
+ override val partitioner = firstParent[Product2[K, U]].partitioner
+
+ override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = {
+ firstParent[Product2[K, V]].iterator(split, context).map { case Product2(k, v) => (k, f(v)) }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
new file mode 100644
index 0000000000..114b504486
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.text.SimpleDateFormat
+import java.util.Date
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+
+import org.apache.spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
+
+
+private[spark]
+class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+ extends Partition {
+
+ val serializableHadoopSplit = new SerializableWritable(rawSplit)
+
+ override def hashCode(): Int = (41 * (41 + rddId) + index)
+}
+
+class NewHadoopRDD[K, V](
+ sc : SparkContext,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K],
+ valueClass: Class[V],
+ @transient conf: Configuration)
+ extends RDD[(K, V)](sc, Nil)
+ with SparkHadoopMapReduceUtil
+ with Logging {
+
+ // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
+ private val confBroadcast = sc.broadcast(new SerializableWritable(conf))
+ // private val serializableConf = new SerializableWritable(conf)
+
+ private val jobtrackerId: String = {
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ formatter.format(new Date())
+ }
+
+ @transient private val jobId = new JobID(jobtrackerId, id)
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ if (inputFormat.isInstanceOf[Configurable]) {
+ inputFormat.asInstanceOf[Configurable].setConf(conf)
+ }
+ val jobContext = newJobContext(conf, jobId)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+
+ override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
+ logInfo("Input split: " + split.serializableHadoopSplit)
+ val conf = confBroadcast.value.value
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => close())
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ }
+ !finished
+ }
+
+ override def next: (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ return (reader.getCurrentKey, reader.getCurrentValue)
+ }
+
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ val theSplit = split.asInstanceOf[NewHadoopPartition]
+ theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ }
+
+ def getConf: Configuration = confBroadcast.value.value
+}
+
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
new file mode 100644
index 0000000000..4c3df0eaf4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{RangePartitioner, Logging, RDD}
+
+/**
+ * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
+ * an implicit conversion. Import `org.apache.spark.SparkContext._` at the top of your program to
+ * use these functions. They will work with any key type that has a `scala.math.Ordered`
+ * implementation.
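+ *
+ * A minimal usage sketch (illustrative; assumes an existing SparkContext `sc`):
+ * {{{
+ *   import org.apache.spark.SparkContext._
+ *   val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3)))
+ *   val ascending = pairs.sortByKey()        // (a,1), (b,2), (c,3)
+ *   val descending = pairs.sortByKey(false)  // (c,3), (b,2), (a,1)
+ * }}}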
+ */
+class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest,
+ V: ClassManifest,
+ P <: Product2[K, V] : ClassManifest](
+ self: RDD[P])
+ extends Logging with Serializable {
+
+ /**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+ val part = new RangePartitioner(numPartitions, self, ascending)
+ val shuffled = new ShuffledRDD[K, V, P](self, part)
+ shuffled.mapPartitions(iter => {
+ val buf = iter.toArray
+ if (ascending) {
+ buf.sortWith((x, y) => x._1 < y._1).iterator
+ } else {
+ buf.sortWith((x, y) => x._1 > y._1).iterator
+ }
+ }, preservesPartitioning = true)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
new file mode 100644
index 0000000000..8db3611054
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.immutable.NumericRange
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
+import org.apache.spark._
+import java.io._
+import scala.Serializable
+
+private[spark] class ParallelCollectionPartition[T: ClassManifest](
+ var rddId: Long,
+ var slice: Int,
+ var values: Seq[T])
+ extends Partition with Serializable {
+
+ def iterator: Iterator[T] = values.iterator
+
+ override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice)
+ case _ => false
+ }
+
+ override def index: Int = slice
+
+ @throws(classOf[IOException])
+ private def writeObject(out: ObjectOutputStream): Unit = {
+
+ val sfactory = SparkEnv.get.serializer
+
+ // Handle the Java serializer with the default write path rather than going through custom
+ // serialization, to avoid writing a separate serialization header.
+
+ sfactory match {
+ case js: JavaSerializer => out.defaultWriteObject()
+ case _ =>
+ out.writeLong(rddId)
+ out.writeInt(slice)
+
+ val ser = sfactory.newInstance()
+ Utils.serializeViaNestedStream(out, ser)(_.writeObject(values))
+ }
+ }
+
+ @throws(classOf[IOException])
+ private def readObject(in: ObjectInputStream): Unit = {
+
+ val sfactory = SparkEnv.get.serializer
+ sfactory match {
+ case js: JavaSerializer => in.defaultReadObject()
+ case _ =>
+ rddId = in.readLong()
+ slice = in.readInt()
+
+ val ser = sfactory.newInstance()
+ Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())
+ }
+ }
+}
+
+private[spark] class ParallelCollectionRDD[T: ClassManifest](
+ @transient sc: SparkContext,
+ @transient data: Seq[T],
+ numSlices: Int,
+ locationPrefs: Map[Int, Seq[String]])
+ extends RDD[T](sc, Nil) {
+ // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
+ // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
+ // instead.
+ // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal.
+
+ override def getPartitions: Array[Partition] = {
+ val slices = ParallelCollectionRDD.slice(data, numSlices).toArray
+ slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
+ }
+
+ override def compute(s: Partition, context: TaskContext) =
+ s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ locationPrefs.getOrElse(s.index, Nil)
+ }
+}
+
+private object ParallelCollectionRDD {
+ /**
+ * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
+ * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
+ * it efficient to run Spark over RDDs representing large sets of numbers.
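+ *
+ * An illustrative example of the special Range handling:
+ * {{{
+ *   // each slice is itself a Range, so no element array is materialized
+ *   ParallelCollectionRDD.slice(1 to 10, 3)   // => slices 1..3, 4..6, 7..10
+ * }}}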
+ */
+ def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
+ if (numSlices < 1) {
+ throw new IllegalArgumentException("Positive number of slices required")
+ }
+ seq match {
+ case r: Range.Inclusive => {
+ val sign = if (r.step < 0) {
+ -1
+ } else {
+ 1
+ }
+ slice(new Range(
+ r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
+ }
+ case r: Range => {
+ (0 until numSlices).map(i => {
+ val start = ((i * r.length.toLong) / numSlices).toInt
+ val end = (((i + 1) * r.length.toLong) / numSlices).toInt
+ new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+ }).asInstanceOf[Seq[Seq[T]]]
+ }
+ case nr: NumericRange[_] => {
+ // For ranges of Long, Double, BigInteger, etc
+ val slices = new ArrayBuffer[Seq[T]](numSlices)
+ val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
+ var r = nr
+ for (i <- 0 until numSlices) {
+ slices += r.take(sliceSize).asInstanceOf[Seq[T]]
+ r = r.drop(sliceSize)
+ }
+ slices
+ }
+ case _ => {
+ val array = seq.toArray // To prevent O(n^2) operations for List etc
+ (0 until numSlices).map(i => {
+ val start = ((i * array.length.toLong) / numSlices).toInt
+ val end = (((i + 1) * array.length.toLong) / numSlices).toInt
+ array.slice(start, end).toSeq
+ })
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
new file mode 100644
index 0000000000..8e79a5c874
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext}
+
+
+class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition {
+ override val index = idx
+}
+
+
+/**
+ * Represents a dependency between the PartitionPruningRDD and its parent. In this
+ * case, the child RDD contains a subset of the parent's partitions.
+ */
+class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
+ extends NarrowDependency[T](rdd) {
+
+ @transient
+ val partitions: Array[Partition] = rdd.partitions.zipWithIndex
+ .filter(s => partitionFilterFunc(s._2))
+ .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
+
+ override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+}
+
+
+/**
+ * An RDD used to prune RDD partitions so we can avoid launching tasks on
+ * all partitions. An example use case: If we know the RDD is partitioned by range,
+ * and the execution DAG has a filter on the key, we can avoid launching tasks
+ * on partitions that don't have the range covering the key.
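+ *
+ * A minimal illustrative sketch (assumes an existing SparkContext `sc`):
+ * {{{
+ *   val rdd = sc.parallelize(1 to 100, 10)
+ *   // keep only the first three partitions; no tasks are launched on the others
+ *   val pruned = PartitionPruningRDD.create(rdd, partitionIndex => partitionIndex < 3)
+ * }}}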
+ */
+class PartitionPruningRDD[T: ClassManifest](
+ @transient prev: RDD[T],
+ @transient partitionFilterFunc: Int => Boolean)
+ extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
+
+ override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(
+ split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context)
+
+ override protected def getPartitions: Array[Partition] =
+ getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
+}
+
+
+object PartitionPruningRDD {
+
+ /**
+ * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD
+ * when its type T is not known at compile time.
+ */
+ def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = {
+ new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
new file mode 100644
index 0000000000..98498d5ddf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.io.PrintWriter
+import java.util.StringTokenizer
+
+import scala.collection.Map
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
+
+import org.apache.spark.{RDD, SparkEnv, Partition, TaskContext}
+import org.apache.spark.broadcast.Broadcast
+
+
+/**
+ * An RDD that pipes the contents of each parent partition through an external command
+ * (printing them one per line) and returns the output as a collection of strings.
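+ *
+ * A minimal usage sketch (illustrative; assumes an existing SparkContext `sc` and that `grep`
+ * is available on the executors; RDD.pipe constructs this class):
+ * {{{
+ *   val rdd = sc.parallelize(Seq("apple", "banana", "cherry"))
+ *   val piped = rdd.pipe("grep an")   // runs `grep an` once per partition
+ * }}}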
+ */
+class PipedRDD[T: ClassManifest](
+ prev: RDD[T],
+ command: Seq[String],
+ envVars: Map[String, String],
+ printPipeContext: (String => Unit) => Unit,
+ printRDDElement: (T, String => Unit) => Unit)
+ extends RDD[String](prev) {
+
+ // Similar to Runtime.exec(), if we are given a single string, split it into words
+ // using a standard StringTokenizer (i.e. by spaces)
+ def this(
+ prev: RDD[T],
+ command: String,
+ envVars: Map[String, String] = Map(),
+ printPipeContext: (String => Unit) => Unit = null,
+ printRDDElement: (T, String => Unit) => Unit = null) =
+ this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
+
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ override def compute(split: Partition, context: TaskContext): Iterator[String] = {
+ val pb = new ProcessBuilder(command)
+ // Add the environmental variables to the process.
+ val currentEnvVars = pb.environment()
+ envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
+
+ val proc = pb.start()
+ val env = SparkEnv.get
+
+ // Start a thread to print the process's stderr to ours
+ new Thread("stderr reader for " + command) {
+ override def run() {
+ for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
+ System.err.println(line)
+ }
+ }
+ }.start()
+
+ // Start a thread to feed the process input from our parent's iterator
+ new Thread("stdin writer for " + command) {
+ override def run() {
+ SparkEnv.set(env)
+ val out = new PrintWriter(proc.getOutputStream)
+
+ // print the pipe context first
+ if (printPipeContext != null) {
+ printPipeContext(out.println(_))
+ }
+ for (elem <- firstParent[T].iterator(split, context)) {
+ if (printRDDElement != null) {
+ printRDDElement(elem, out.println(_))
+ } else {
+ out.println(elem)
+ }
+ }
+ out.close()
+ }
+ }.start()
+
+ // Return an iterator that reads lines from the process's stdout
+ val lines = Source.fromInputStream(proc.getInputStream).getLines
+ return new Iterator[String] {
+ def next() = lines.next()
+ def hasNext = {
+ if (lines.hasNext) {
+ true
+ } else {
+ val exitStatus = proc.waitFor()
+ if (exitStatus != 0) {
+ throw new Exception("Subprocess exited with status " + exitStatus)
+ }
+ false
+ }
+ }
+ }
+ }
+}
+
+object PipedRDD {
+ // Split a string into words using a standard StringTokenizer
+ def tokenize(command: String): Seq[String] = {
+ val buf = new ArrayBuffer[String]
+ val tok = new StringTokenizer(command)
+ while (tok.hasMoreElements) {
+ buf += tok.nextToken()
+ }
+ buf
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
new file mode 100644
index 0000000000..1e8d89e912
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.Random
+
+import cern.jet.random.Poisson
+import cern.jet.random.engine.DRand
+
+import org.apache.spark.{RDD, Partition, TaskContext}
+
+private[spark]
+class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable {
+ override val index: Int = prev.index
+}
+
+class SampledRDD[T: ClassManifest](
+ prev: RDD[T],
+ withReplacement: Boolean,
+ frac: Double,
+ seed: Int)
+ extends RDD[T](prev) {
+
+ override def getPartitions: Array[Partition] = {
+ val rg = new Random(seed)
+ firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt))
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] =
+ firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev)
+
+ override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = {
+ val split = splitIn.asInstanceOf[SampledRDDPartition]
+ if (withReplacement) {
+ // For large datasets, the expected number of occurrences of each element in a sample with
+ // replacement is Poisson(frac). We use that to get a count for each element.
+ val poisson = new Poisson(frac, new DRand(split.seed))
+ firstParent[T].iterator(split.prev, context).flatMap { element =>
+ val count = poisson.nextInt()
+ if (count == 0) {
+ Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
+ } else {
+ Iterator.fill(count)(element)
+ }
+ }
+ } else { // Sampling without replacement
+ val rand = new Random(split.seed)
+ firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac))
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
new file mode 100644
index 0000000000..f0e9ab8b80
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{Dependency, Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext}
+
+
+private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
+ override val index = idx
+ override def hashCode(): Int = idx
+}
+
+/**
+ * The resulting RDD from a shuffle (e.g. repartitioning of data).
+ * @param prev the parent RDD.
+ * @param part the partitioner used to partition the RDD
+ * @tparam K the key class.
+ * @tparam V the value class.
+ */
+class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
+ @transient var prev: RDD[P],
+ part: Partitioner)
+ extends RDD[P](prev.context, Nil) {
+
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): ShuffledRDD[K, V, P] = {
+ serializerClass = cls
+ this
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ List(new ShuffleDependency(prev, part, serializerClass))
+ }
+
+ override val partitioner = Some(part)
+
+ override def getPartitions: Array[Partition] = {
+ Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[P] = {
+ val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.serializerManager.get(serializerClass))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
+}
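A hedged sketch of repartitioning a pair RDD with ShuffledRDD, assuming an existing SparkContext named sc; this is roughly how partitionBy-style operations use the class:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.ShuffledRDD

    // Sketch only: redistribute key/value pairs into 4 hash partitions.
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    val repartitioned = new ShuffledRDD[String, Int, (String, Int)](pairs, new HashPartitioner(4))
    println(repartitioned.partitions.size)   // 4; all values for a given key share a partition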
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
new file mode 100644
index 0000000000..7369dfaa74
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.{HashMap => JHashMap}
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.RDD
+import org.apache.spark.Partitioner
+import org.apache.spark.Dependency
+import org.apache.spark.TaskContext
+import org.apache.spark.Partition
+import org.apache.spark.SparkEnv
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.OneToOneDependency
+
+
+/**
+ * An optimized version of cogroup for set difference/subtraction.
+ *
+ * It is possible to implement this operation with just `cogroup`, but
+ * that is less efficient because all of the entries from `rdd2`, for
+ * both matching and non-matching values in `rdd1`, are kept in the
+ * JHashMap until the end.
+ *
+ * With this implementation, only the entries from `rdd1` are kept in-memory,
+ * and the entries from `rdd2` are essentially streamed, as we only need to
+ * touch each one once to decide whether the value needs to be removed.
+ *
+ * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as
+ * you can use `rdd1`'s partitioner/partition size and not worry about running
+ * out of memory because of the size of `rdd2`.
+ */
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+ @transient var rdd1: RDD[_ <: Product2[K, V]],
+ @transient var rdd2: RDD[_ <: Product2[K, W]],
+ part: Partitioner)
+ extends RDD[(K, V)](rdd1.context, Nil) {
+
+ private var serializerClass: String = null
+
+ def setSerializer(cls: String): SubtractedRDD[K, V, W] = {
+ serializerClass = cls
+ this
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ Seq(rdd1, rdd2).map { rdd =>
+ if (rdd.partitioner == Some(part)) {
+ logDebug("Adding one-to-one dependency with " + rdd)
+ new OneToOneDependency(rdd)
+ } else {
+ logDebug("Adding shuffle dependency with " + rdd)
+ new ShuffleDependency(rdd, part, serializerClass)
+ }
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](part.numPartitions)
+ for (i <- 0 until array.size) {
+ // Each CoGroupPartition will depend on rdd1 and rdd2
+ array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
+ dependencies(j) match {
+ case s: ShuffleDependency[_, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleId)
+ case _ =>
+ new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+ }
+ }.toArray)
+ }
+ array
+ }
+
+ override val partitioner = Some(part)
+
+ override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
+ val partition = p.asInstanceOf[CoGroupPartition]
+ val serializer = SparkEnv.get.serializerManager.get(serializerClass)
+ val map = new JHashMap[K, ArrayBuffer[V]]
+ def getSeq(k: K): ArrayBuffer[V] = {
+ val seq = map.get(k)
+ if (seq != null) {
+ seq
+ } else {
+ val seq = new ArrayBuffer[V]()
+ map.put(k, seq)
+ seq
+ }
+ }
+ def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match {
+ case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
+ rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
+ }
+ case ShuffleCoGroupSplitDep(shuffleId) => {
+ val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
+ context.taskMetrics, serializer)
+ iter.foreach(op)
+ }
+ }
+ // the first dep is rdd1; add all values to the map
+ integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+ // the second dep is rdd2; remove all of its keys
+ integrate(partition.deps(1), t => map.remove(t._1))
+ map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+
+}
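A hedged sketch of the set-difference semantics above, assuming an existing SparkContext named sc. The class is private[spark], so the direct construction below is illustrative only; user code normally goes through PairRDDFunctions.subtractByKey:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.SubtractedRDD

    // Sketch only: keep entries of rdd1 whose keys do not appear in rdd2.
    val rdd1 = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    val rdd2 = sc.parallelize(Seq(("b", 99)))
    val diff = new SubtractedRDD[String, Int, Int](rdd1, rdd2, new HashPartitioner(2))
    println(diff.collect.toList)   // List((a,1), (c,3)) in some order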
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
new file mode 100644
index 0000000000..fd02476b62
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+ extends Partition {
+
+ var split: Partition = rdd.partitions(splitIndex)
+
+ def iterator(context: TaskContext) = rdd.iterator(split, context)
+
+ def preferredLocations() = rdd.preferredLocations(split)
+
+ override val index: Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ split = rdd.partitions(splitIndex)
+ oos.defaultWriteObject()
+ }
+}
+
+class UnionRDD[T: ClassManifest](
+ sc: SparkContext,
+ @transient var rdds: Seq[RDD[T]])
+ extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
+
+ override def getPartitions: Array[Partition] = {
+ val array = new Array[Partition](rdds.map(_.partitions.size).sum)
+ var pos = 0
+ for (rdd <- rdds; split <- rdd.partitions) {
+ array(pos) = new UnionPartition(pos, rdd, split.index)
+ pos += 1
+ }
+ array
+ }
+
+ override def getDependencies: Seq[Dependency[_]] = {
+ val deps = new ArrayBuffer[Dependency[_]]
+ var pos = 0
+ for (rdd <- rdds) {
+ deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size)
+ pos += rdd.partitions.size
+ }
+ deps
+ }
+
+ override def compute(s: Partition, context: TaskContext): Iterator[T] =
+ s.asInstanceOf[UnionPartition[T]].iterator(context)
+
+ override def getPreferredLocations(s: Partition): Seq[String] =
+ s.asInstanceOf[UnionPartition[T]].preferredLocations()
+}
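A hedged sketch of the union behaviour above, assuming an existing SparkContext named sc; this is roughly what RDD.union builds:

    import org.apache.spark.rdd.UnionRDD

    // Sketch only: concatenate two RDDs; every parent partition becomes one partition of the union.
    val a = sc.parallelize(1 to 3, 2)
    val b = sc.parallelize(4 to 6, 3)
    val u = new UnionRDD(sc, Seq(a, b))
    println(u.partitions.size)    // 5 = 2 + 3
    println(u.collect.toList)     // List(1, 2, 3, 4, 5, 6)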
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
new file mode 100644
index 0000000000..5ae1db3e67
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class ZippedPartitionsPartition(
+ idx: Int,
+ @transient rdds: Seq[RDD[_]])
+ extends Partition {
+
+ override val index: Int = idx
+ var partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ def partitions = partitionValues
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent split at the time of task serialization
+ partitionValues = rdds.map(rdd => rdd.partitions(idx))
+ oos.defaultWriteObject()
+ }
+}
+
+abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+ sc: SparkContext,
+ var rdds: Seq[RDD[_]])
+ extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+
+ override def getPartitions: Array[Partition] = {
+ val sizes = rdds.map(x => x.partitions.size)
+ if (!sizes.forall(x => x == sizes(0))) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Partition](sizes(0))
+ for (i <- 0 until sizes(0)) {
+ array(i) = new ZippedPartitionsPartition(i, rdds)
+ }
+ array
+ }
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
+ // Check whether there are any hosts that match all RDDs; otherwise return the union
+ val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
+ if (!exactMatchLocations.isEmpty) {
+ exactMatchLocations
+ } else {
+ prefs.flatten.distinct
+ }
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+}
+
+class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
+
+class ZippedPartitionsRDD3
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ }
+}
+
+class ZippedPartitionsRDD4
+ [A: ClassManifest, B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
+ sc: SparkContext,
+ f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var rdd1: RDD[A],
+ var rdd2: RDD[B],
+ var rdd3: RDD[C],
+ var rdd4: RDD[D])
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+
+ override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+ val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+ f(rdd1.iterator(partitions(0), context),
+ rdd2.iterator(partitions(1), context),
+ rdd3.iterator(partitions(2), context),
+ rdd4.iterator(partitions(3), context))
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ rdd3 = null
+ rdd4 = null
+ }
+}
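A hedged sketch of combining co-partitioned RDDs with ZippedPartitionsRDD2, assuming an existing SparkContext named sc; both inputs must have the same number of partitions:

    import org.apache.spark.rdd.ZippedPartitionsRDD2

    // Sketch only: merge the per-partition iterators of two RDDs into one output iterator.
    val xs = sc.parallelize(1 to 4, 2)
    val ys = sc.parallelize(Seq("a", "b", "c", "d"), 2)
    val zipped = new ZippedPartitionsRDD2[Int, String, String](
      sc, (i, s) => i.zip(s).map { case (n, c) => n + c }, xs, ys)
    println(zipped.collect.toList)   // List(1a, 2b, 3c, 4d)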
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
new file mode 100644
index 0000000000..3bd00d291b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{Utils, OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+
+private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
+ idx: Int,
+ @transient rdd1: RDD[T],
+ @transient rdd2: RDD[U]
+ ) extends Partition {
+
+ var partition1 = rdd1.partitions(idx)
+ var partition2 = rdd2.partitions(idx)
+ override val index: Int = idx
+
+ def partitions = (partition1, partition2)
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent partition at the time of task serialization
+ partition1 = rdd1.partitions(idx)
+ partition2 = rdd2.partitions(idx)
+ oos.defaultWriteObject()
+ }
+}
+
+class ZippedRDD[T: ClassManifest, U: ClassManifest](
+ sc: SparkContext,
+ var rdd1: RDD[T],
+ var rdd2: RDD[U])
+ extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
+
+ override def getPartitions: Array[Partition] = {
+ if (rdd1.partitions.size != rdd2.partitions.size) {
+ throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+ }
+ val array = new Array[Partition](rdd1.partitions.size)
+ for (i <- 0 until rdd1.partitions.size) {
+ array(i) = new ZippedPartition(i, rdd1, rdd2)
+ }
+ array
+ }
+
+ override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
+ val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
+ rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
+ }
+
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
+ val pref1 = rdd1.preferredLocations(partition1)
+ val pref2 = rdd2.preferredLocations(partition2)
+ // Check whether there are any hosts that match both RDDs; otherwise return the union
+ val exactMatchLocations = pref1.intersect(pref2)
+ if (!exactMatchLocations.isEmpty) {
+ exactMatchLocations
+ } else {
+ (pref1 ++ pref2).distinct
+ }
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdd1 = null
+ rdd2 = null
+ }
+}
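A hedged sketch of pairwise zipping with ZippedRDD, assuming an existing SparkContext named sc; the two inputs must have identical partition counts (and, for sensible results, identical per-partition sizes):

    import org.apache.spark.rdd.ZippedRDD

    // Sketch only: zip two aligned RDDs into an RDD of pairs.
    val ids = sc.parallelize(1 to 3, 3)
    val names = sc.parallelize(Seq("a", "b", "c"), 3)
    val pairs = new ZippedRDD(sc, ids, names)
    println(pairs.collect.toList)   // List((1,a), (2,b), (3,c))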
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
new file mode 100644
index 0000000000..0b04607d01
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.TaskContext
+
+import java.util.Properties
+
+/**
+ * Tracks information about an active job in the DAGScheduler.
+ */
+private[spark] class ActiveJob(
+ val jobId: Int,
+ val finalStage: Stage,
+ val func: (TaskContext, Iterator[_]) => _,
+ val partitions: Array[Int],
+ val callSite: String,
+ val listener: JobListener,
+ val properties: Properties) {
+
+ val numPartitions = partitions.length
+ val finished = Array.fill[Boolean](numPartitions)(false)
+ var numFinished = 0
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
new file mode 100644
index 0000000000..5ac700bbf4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -0,0 +1,849 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.NotSerializableException
+import java.util.Properties
+import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+
+import org.apache.spark._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
+import org.apache.spark.scheduler.cluster.TaskInfo
+import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
+import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
+
+/**
+ * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
+ * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a
+ * minimal schedule to run the job. It then submits stages as TaskSets to an underlying
+ * TaskScheduler implementation that runs them on the cluster.
+ *
+ * In addition to coming up with a DAG of stages, this class also determines the preferred
+ * locations to run each task on, based on the current cache status, and passes these to the
+ * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
+ * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
+ * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
+ * a small number of times before cancelling the whole stage.
+ *
+ * THREADING: This class runs all its logic in a single thread executing the run() method, to which
+ * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
+ * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
+ * should be private.
+ */
+private[spark]
+class DAGScheduler(
+ taskSched: TaskScheduler,
+ mapOutputTracker: MapOutputTracker,
+ blockManagerMaster: BlockManagerMaster,
+ env: SparkEnv)
+ extends TaskSchedulerListener with Logging {
+
+ def this(taskSched: TaskScheduler) {
+ this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ }
+ taskSched.setListener(this)
+
+ // Called by TaskScheduler to report task's starting.
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ eventQueue.put(BeginEvent(task, taskInfo))
+ }
+
+ // Called by TaskScheduler to report task completions or failures.
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+ }
+
+ // Called by TaskScheduler when an executor fails.
+ override def executorLost(execId: String) {
+ eventQueue.put(ExecutorLost(execId))
+ }
+
+ // Called by TaskScheduler when a host is added
+ override def executorGained(execId: String, host: String) {
+ eventQueue.put(ExecutorGained(execId, host))
+ }
+
+ // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
+ override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ eventQueue.put(TaskSetFailed(taskSet, reason))
+ }
+
+ // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
+ // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
+ // as more failure events come in
+ val RESUBMIT_TIMEOUT = 50L
+
+ // The time, in millis, to wake up between polls of the completion queue in order to potentially
+ // resubmit failed stages
+ val POLL_TIMEOUT = 10L
+
+ private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+
+ val nextJobId = new AtomicInteger(0)
+
+ val nextStageId = new AtomicInteger(0)
+
+ val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+
+ private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+
+ private val listenerBus = new SparkListenerBus()
+
+ // Contains the locations that each RDD's partitions are cached on
+ private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
+
+ // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
+ // every task. When we detect a node failing, we note the current epoch number and failed
+ // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results.
+ //
+ // TODO: Garbage collect information about failure epochs when we know there are no more
+ // stray messages to detect.
+ val failedEpoch = new HashMap[String, Long]
+
+ val idToActiveJob = new HashMap[Int, ActiveJob]
+
+ val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
+ val running = new HashSet[Stage] // Stages we are running right now
+ val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
+ val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage
+ var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits
+
+ val activeJobs = new HashSet[ActiveJob]
+ val resultStageToJob = new HashMap[Stage, ActiveJob]
+
+ val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
+
+ // Start a thread to run the DAGScheduler event loop
+ def start() {
+ new Thread("DAGScheduler") {
+ setDaemon(true)
+ override def run() {
+ DAGScheduler.this.run()
+ }
+ }.start()
+ }
+
+ def addSparkListener(listener: SparkListener) {
+ listenerBus.addListener(listener)
+ }
+
+ private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
+ if (!cacheLocs.contains(rdd.id)) {
+ val blockIds = rdd.partitions.indices.map(index => "rdd_%d_%d".format(rdd.id, index)).toArray
+ val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
+ cacheLocs(rdd.id) = blockIds.map { id =>
+ locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
+ }
+ }
+ cacheLocs(rdd.id)
+ }
+
+ private def clearCacheLocs() {
+ cacheLocs.clear()
+ }
+
+ /**
+ * Get or create a shuffle map stage for the given shuffle dependency's map side.
+ * The jobId value passed in will be used if the stage doesn't already exist with
+ * a lower jobId (jobId always increases across jobs.)
+ */
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
+ shuffleToMapStage.get(shuffleDep.shuffleId) match {
+ case Some(stage) => stage
+ case None =>
+ val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ shuffleToMapStage(shuffleDep.shuffleId) = stage
+ stage
+ }
+ }
+
+ /**
+ * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
+ * as a result stage for the final RDD used directly in an action. The stage will also be
+ * associated with the provided jobId.
+ */
+ private def newStage(
+ rdd: RDD[_],
+ shuffleDep: Option[ShuffleDependency[_,_]],
+ jobId: Int,
+ callSite: Option[String] = None)
+ : Stage =
+ {
+ if (shuffleDep != None) {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
+ }
+ val id = nextStageId.getAndIncrement()
+ val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ stageIdToStage(id) = stage
+ stageToInfos(stage) = StageInfo(stage)
+ stage
+ }
+
+ /**
+ * Get or create the list of parent stages for a given RDD. The stages will be assigned the
+ * provided jobId if they haven't already been created with a lower jobId.
+ */
+ private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
+ val parents = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(r: RDD[_]) {
+ if (!visited(r)) {
+ visited += r
+ // Kind of ugly: need to register RDDs with the cache here since
+ // we can't do it in its constructor because # of partitions is unknown
+ for (dep <- r.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ parents += getShuffleMapStage(shufDep, jobId)
+ case _ =>
+ visit(dep.rdd)
+ }
+ }
+ }
+ }
+ visit(rdd)
+ parents.toList
+ }
+
+ private def getMissingParentStages(stage: Stage): List[Stage] = {
+ val missing = new HashSet[Stage]
+ val visited = new HashSet[RDD[_]]
+ def visit(rdd: RDD[_]) {
+ if (!visited(rdd)) {
+ visited += rdd
+ if (getCacheLocs(rdd).contains(Nil)) {
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.jobId)
+ if (!mapStage.isAvailable) {
+ missing += mapStage
+ }
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ missing.toList
+ }
+
+ /**
+ * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
+ * JobWaiter whose getResult() method will return the result of the job when it is complete.
+ *
+ * The job is assumed to have at least one partition; zero partition jobs should be handled
+ * without a JobSubmitted event.
+ */
+ private[scheduler] def prepareJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
+ : (JobSubmitted, JobWaiter[U]) =
+ {
+ assert(partitions.size > 0)
+ val waiter = new JobWaiter(partitions.size, resultHandler)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
+ properties)
+ (toSubmit, waiter)
+ }
+
+ def runJob[T, U: ClassManifest](
+ finalRdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ partitions: Seq[Int],
+ callSite: String,
+ allowLocal: Boolean,
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
+ {
+ if (partitions.size == 0) {
+ return
+ }
+
+ // Check to make sure we are not launching a task on a partition that does not exist.
+ val maxPartitions = finalRdd.partitions.length
+ partitions.find(p => p >= maxPartitions).foreach { p =>
+ throw new IllegalArgumentException(
+ "Attempting to access a non-existent partition: " + p + ". " +
+ "Total number of partitions: " + maxPartitions)
+ }
+
+ val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
+ eventQueue.put(toSubmit)
+ waiter.awaitResult() match {
+ case JobSucceeded => {}
+ case JobFailed(exception: Exception, _) =>
+ logInfo("Failed to run " + callSite)
+ throw exception
+ }
+ }
+
+ def runApproximateJob[T, U, R](
+ rdd: RDD[T],
+ func: (TaskContext, Iterator[T]) => U,
+ evaluator: ApproximateEvaluator[U, R],
+ callSite: String,
+ timeout: Long,
+ properties: Properties = null)
+ : PartialResult[R] =
+ {
+ val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
+ val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
+ val partitions = (0 until rdd.partitions.size).toArray
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+ listener.awaitResult() // Will throw an exception if the job fails
+ }
+
+ /**
+ * Process one event retrieved from the event queue.
+ * Returns true if we should stop the event loop.
+ */
+ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
+ event match {
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
+ val jobId = nextJobId.getAndIncrement()
+ val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+ val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
+ clearCacheLocs()
+ logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
+ " output partitions (allowLocal=" + allowLocal + ")")
+ logInfo("Final stage: " + finalStage + " (" + finalStage.name + ")")
+ logInfo("Parents of final stage: " + finalStage.parents)
+ logInfo("Missing parents: " + getMissingParentStages(finalStage))
+ if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
+ // Compute very short actions like first() or take() with no parent stages locally.
+ runLocally(job)
+ } else {
+ listenerBus.post(SparkListenerJobStart(job, properties))
+ idToActiveJob(jobId) = job
+ activeJobs += job
+ resultStageToJob(finalStage) = job
+ submitStage(finalStage)
+ }
+
+ case ExecutorGained(execId, host) =>
+ handleExecutorGained(execId, host)
+
+ case ExecutorLost(execId) =>
+ handleExecutorLost(execId)
+
+ case begin: BeginEvent =>
+ listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+
+ case completion: CompletionEvent =>
+ listenerBus.post(SparkListenerTaskEnd(
+ completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ handleTaskCompletion(completion)
+
+ case TaskSetFailed(taskSet, reason) =>
+ abortStage(stageIdToStage(taskSet.stageId), reason)
+
+ case StopDAGScheduler =>
+ // Cancel any active jobs
+ for (job <- activeJobs) {
+ val error = new SparkException("Job cancelled because SparkContext was shut down")
+ job.listener.jobFailed(error)
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, None)))
+ }
+ return true
+ }
+ false
+ }
+
+ /**
+ * Resubmit any failed stages. Ordinarily called after a small amount of time has passed since
+ * the last fetch failure.
+ */
+ private[scheduler] def resubmitFailedStages() {
+ logInfo("Resubmitting failed stages")
+ clearCacheLocs()
+ val failed2 = failed.toArray
+ failed.clear()
+ for (stage <- failed2.sortBy(_.jobId)) {
+ submitStage(stage)
+ }
+ }
+
+ /**
+ * Check for waiting or failed stages which are now eligible for resubmission.
+ * Ordinarily run on every iteration of the event loop.
+ */
+ private[scheduler] def submitWaitingStages() {
+ // TODO: We might want to run this less often, when we are sure that something has become
+ // runnable that wasn't before.
+ logTrace("Checking for newly runnable parent stages")
+ logTrace("running: " + running)
+ logTrace("waiting: " + waiting)
+ logTrace("failed: " + failed)
+ val waiting2 = waiting.toArray
+ waiting.clear()
+ for (stage <- waiting2.sortBy(_.jobId)) {
+ submitStage(stage)
+ }
+ }
+
+
+ /**
+ * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+ * events and responds by launching tasks. This runs in a dedicated thread and receives events
+ * via the eventQueue.
+ */
+ private def run() {
+ SparkEnv.set(env)
+
+ while (true) {
+ val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
+ if (event != null) {
+ logDebug("Got event of type " + event.getClass.getName)
+ }
+ this.synchronized { // needed in case other threads make calls into methods of this class
+ if (event != null) {
+ if (processEvent(event)) {
+ return
+ }
+ }
+
+ val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
+ // Periodically resubmit failed stages if some map output fetches have failed and we have
+ // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
+ // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
+ // the same time, so we want to make sure we've identified all the reduce tasks that depend
+ // on the failed node.
+ if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
+ resubmitFailedStages()
+ } else {
+ submitWaitingStages()
+ }
+ }
+ }
+ }
+
+ /**
+ * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
+ * We run the operation in a separate thread just in case it takes a bunch of time, so that we
+ * don't block the DAGScheduler event loop or other concurrent jobs.
+ */
+ protected def runLocally(job: ActiveJob) {
+ logInfo("Computing the requested partition locally")
+ new Thread("Local computation of job " + job.jobId) {
+ override def run() {
+ runLocallyWithinThread(job)
+ }
+ }.start()
+ }
+
+ // Broken out for easier testing in DAGSchedulerSuite.
+ protected def runLocallyWithinThread(job: ActiveJob) {
+ try {
+ SparkEnv.set(env)
+ val rdd = job.finalStage.rdd
+ val split = rdd.partitions(job.partitions(0))
+ val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+ try {
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ job.listener.taskSucceeded(0, result)
+ } finally {
+ taskContext.executeOnCompleteCallbacks()
+ }
+ } catch {
+ case e: Exception =>
+ job.listener.jobFailed(e)
+ }
+ }
+
+ /** Submits stage, but first recursively submits any missing parents. */
+ private def submitStage(stage: Stage) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
+ submitMissingTasks(stage)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
+ }
+ }
+ }
+
+ /** Called when stage's parents are available and we can now do its task. */
+ private def submitMissingTasks(stage: Stage) {
+ logDebug("submitMissingTasks(" + stage + ")")
+ // Get our pending tasks and remember them in our pendingTasks entry
+ val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
+ myPending.clear()
+ var tasks = ArrayBuffer[Task[_]]()
+ if (stage.isShuffleMap) {
+ for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
+ val locs = getPreferredLocs(stage.rdd, p)
+ tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+ }
+ } else {
+ // This is a final stage; figure out its job's missing partitions
+ val job = resultStageToJob(stage)
+ for (id <- 0 until job.numPartitions if !job.finished(id)) {
+ val partition = job.partitions(id)
+ val locs = getPreferredLocs(stage.rdd, partition)
+ tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+ }
+ }
+ // The listener must be notified before the possible NotSerializableException below, so that
+ // it sees "StageSubmitted" before "JobEnded"
+ val properties = idToActiveJob(stage.jobId).properties
+ listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+
+ if (tasks.size > 0) {
+ // Preemptively serialize a task to make sure it can be serialized. We are catching this
+ // exception here because it would be fairly hard to catch the non-serializable exception
+ // down the road, where we have several different implementations for local scheduler and
+ // cluster schedulers.
+ try {
+ SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
+ } catch {
+ case e: NotSerializableException =>
+ abortStage(stage, e.toString)
+ running -= stage
+ return
+ }
+
+ logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
+ myPending ++= tasks
+ logDebug("New pending tasks: " + myPending)
+ taskSched.submitTasks(
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
+ if (!stage.submissionTime.isDefined) {
+ stage.submissionTime = Some(System.currentTimeMillis())
+ }
+ } else {
+ logDebug("Stage " + stage + " is actually done; %b %d %d".format(
+ stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
+ running -= stage
+ }
+ }
+
+ /**
+ * Responds to a task finishing. This is called inside the event loop so it assumes that it can
+ * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
+ */
+ private def handleTaskCompletion(event: CompletionEvent) {
+ val task = event.task
+ val stage = stageIdToStage(task.stageId)
+
+ def markStageAsFinished(stage: Stage) = {
+ val serviceTime = stage.submissionTime match {
+ case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
+ case _ => "Unknown"
+ }
+ logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
+ stage.completionTime = Some(System.currentTimeMillis)
+ listenerBus.post(StageCompleted(stageToInfos(stage)))
+ running -= stage
+ }
+ event.reason match {
+ case Success =>
+ logInfo("Completed " + task)
+ if (event.accumUpdates != null) {
+ Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+ }
+ pendingTasks(stage) -= task
+ stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
+ task match {
+ case rt: ResultTask[_, _] =>
+ resultStageToJob.get(stage) match {
+ case Some(job) =>
+ if (!job.finished(rt.outputId)) {
+ job.finished(rt.outputId) = true
+ job.numFinished += 1
+ // If the whole job has finished, remove it
+ if (job.numFinished == job.numPartitions) {
+ idToActiveJob -= stage.jobId
+ activeJobs -= job
+ resultStageToJob -= stage
+ markStageAsFinished(stage)
+ listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
+ }
+ job.listener.taskSucceeded(rt.outputId, event.result)
+ }
+ case None =>
+ logInfo("Ignoring result from " + rt + " because its job has finished")
+ }
+
+ case smt: ShuffleMapTask =>
+ val status = event.result.asInstanceOf[MapStatus]
+ val execId = status.location.executorId
+ logDebug("ShuffleMapTask finished on " + execId)
+ if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
+ logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ } else {
+ stage.addOutputLoc(smt.partition, status)
+ }
+ if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+ markStageAsFinished(stage)
+ logInfo("looking for newly runnable stages")
+ logInfo("running: " + running)
+ logInfo("waiting: " + waiting)
+ logInfo("failed: " + failed)
+ if (stage.shuffleDep != None) {
+ // We supply true to increment the epoch number here in case this is a
+ // recomputation of the map outputs. In that case, some nodes may have cached
+ // locations with holes (from when we detected the error) and will need the
+ // epoch incremented to refetch them.
+ // TODO: Only increment the epoch number if this is not the first time
+ // we registered these map outputs.
+ mapOutputTracker.registerMapOutputs(
+ stage.shuffleDep.get.shuffleId,
+ stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
+ changeEpoch = true)
+ }
+ clearCacheLocs()
+ if (stage.outputLocs.count(_ == Nil) != 0) {
+ // Some tasks had failed; let's resubmit this stage
+ // TODO: Lower-level scheduler should also deal with this
+ logInfo("Resubmitting " + stage + " (" + stage.name +
+ ") because some of its tasks had failed: " +
+ stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
+ submitStage(stage)
+ } else {
+ val newlyRunnable = new ArrayBuffer[Stage]
+ for (stage <- waiting) {
+ logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+ }
+ for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+ newlyRunnable += stage
+ }
+ waiting --= newlyRunnable
+ running ++= newlyRunnable
+ for (stage <- newlyRunnable.sortBy(_.id)) {
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
+ submitMissingTasks(stage)
+ }
+ }
+ }
+ }
+
+ case Resubmitted =>
+ logInfo("Resubmitted " + task + ", so marking it as still running")
+ pendingTasks(stage) += task
+
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ // Mark the stage that the reducer was in as unrunnable
+ val failedStage = stageIdToStage(task.stageId)
+ running -= failedStage
+ failed += failedStage
+ // TODO: Cancel running tasks in the stage
+ logInfo("Marking " + failedStage + " (" + failedStage.name +
+ ") for resubmision due to a fetch failure")
+ // Mark the map whose fetch failed as broken in the map stage
+ val mapStage = shuffleToMapStage(shuffleId)
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
+ logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name +
+ "); marking it for resubmission")
+ failed += mapStage
+ // Remember that a fetch failed now; this is used to resubmit the broken
+ // stages later, after a small wait (to give other tasks the chance to fail)
+ lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ }
+
+ case ExceptionFailure(className, description, stackTrace, metrics) =>
+ // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+
+ case other =>
+ // Unrecognized failure - abort all jobs depending on this stage
+ abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
+ }
+ }
+
+ /**
+ * Responds to an executor being lost. This is called inside the event loop, so it assumes it can
+ * modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
+ *
+ * Optionally the epoch during which the failure was caught can be passed to avoid allowing
+ * stray fetch failures from possibly retriggering the detection of a node as lost.
+ */
+ private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
+ if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
+ failedEpoch(execId) = currentEpoch
+ logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
+ blockManagerMaster.removeExecutor(execId)
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
+ } else {
+ logDebug("Additional executor lost message for " + execId +
+ "(epoch " + currentEpoch + ")")
+ }
+ }
+
+ private def handleExecutorGained(execId: String, host: String) {
+ // remove from failedEpoch(execId) ?
+ if (failedEpoch.contains(execId)) {
+ logInfo("Host gained which was in lost list earlier: " + host)
+ failedEpoch -= execId
+ }
+ }
+
+ /**
+ * Aborts all jobs depending on a particular Stage. This is called in response to a task set
+ * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ */
+ private def abortStage(failedStage: Stage, reason: String) {
+ val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
+ failedStage.completionTime = Some(System.currentTimeMillis())
+ for (resultStage <- dependentStages) {
+ val job = resultStageToJob(resultStage)
+ val error = new SparkException("Job failed: " + reason)
+ job.listener.jobFailed(error)
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
+ idToActiveJob -= resultStage.jobId
+ activeJobs -= job
+ resultStageToJob -= resultStage
+ }
+ if (dependentStages.isEmpty) {
+ logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
+ }
+ }
+
+ /**
+ * Return true if one of stage's ancestors is target.
+ */
+ private def stageDependsOn(stage: Stage, target: Stage): Boolean = {
+ if (stage == target) {
+ return true
+ }
+ val visitedRdds = new HashSet[RDD[_]]
+ val visitedStages = new HashSet[Stage]
+ def visit(rdd: RDD[_]) {
+ if (!visitedRdds(rdd)) {
+ visitedRdds += rdd
+ for (dep <- rdd.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val mapStage = getShuffleMapStage(shufDep, stage.jobId)
+ if (!mapStage.isAvailable) {
+ visitedStages += mapStage
+ visit(mapStage.rdd)
+ } // Otherwise there's no need to follow the dependency back
+ case narrowDep: NarrowDependency[_] =>
+ visit(narrowDep.rdd)
+ }
+ }
+ }
+ }
+ visit(stage.rdd)
+ visitedRdds.contains(target.rdd)
+ }
+
+ /**
+ * Synchronized method that might be called from other threads.
+ * @param rdd the RDD whose partitions are to be looked at
+ * @param partition the partition to look up locality information for
+ * @return the list of machines that are preferred by the partition
+ */
+ private[spark]
+ def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
+ // If the partition is cached, return the cache locations
+ val cached = getCacheLocs(rdd)(partition)
+ if (!cached.isEmpty) {
+ return cached
+ }
+ // If the RDD has some placement preferences (as is the case for input RDDs), get those
+ val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
+ if (!rddPrefs.isEmpty) {
+ return rddPrefs.map(host => TaskLocation(host))
+ }
+ // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
+ // that has any placement preferences. Ideally we would choose based on transfer sizes,
+ // but this will do for now.
+ rdd.dependencies.foreach(_ match {
+ case n: NarrowDependency[_] =>
+ for (inPart <- n.getParents(partition)) {
+ val locs = getPreferredLocs(n.rdd, inPart)
+ if (locs != Nil)
+ return locs
+ }
+ case _ =>
+ })
+ Nil
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ var sizeBefore = stageIdToStage.size
+ stageIdToStage.clearOldValues(cleanupTime)
+ logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size)
+
+ sizeBefore = shuffleToMapStage.size
+ shuffleToMapStage.clearOldValues(cleanupTime)
+ logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
+
+ sizeBefore = pendingTasks.size
+ pendingTasks.clearOldValues(cleanupTime)
+ logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
+
+ sizeBefore = stageToInfos.size
+ stageToInfos.clearOldValues(cleanupTime)
+ logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
+ }
+
+ def stop() {
+ eventQueue.put(StopDAGScheduler)
+ metadataCleaner.cancel()
+ taskSched.stop()
+ }
+}
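A hedged sketch of how an action drives the scheduler above, assuming a started DAGScheduler named dagScheduler and an RDD[Int] named rdd (hypothetical names); runJob blocks until the job completes and invokes the result handler once per partition:

    import org.apache.spark.TaskContext

    // Sketch only: count elements per partition by running a task that returns iter.size.
    val counts = new Array[Int](rdd.partitions.size)
    dagScheduler.runJob(
      rdd,
      (context: TaskContext, iter: Iterator[Int]) => iter.size,
      0 until rdd.partitions.size,                       // run on every partition
      "count-sketch",                                    // call site string shown in logs
      false,                                             // allowLocal
      (index: Int, result: Int) => counts(index) = result)
    println("total = " + counts.sum)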
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
new file mode 100644
index 0000000000..5b07933eed
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+import org.apache.spark.scheduler.cluster.TaskInfo
+import scala.collection.mutable.Map
+
+import org.apache.spark._
+import org.apache.spark.executor.TaskMetrics
+
+/**
+ * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
+ * architecture where any thread can post an event (e.g. a task finishing or a new job being
+ * submitted) but there is a single "logic" thread that reads these events and makes decisions.
+ * This greatly simplifies synchronization.
+ */
+private[spark] sealed trait DAGSchedulerEvent
+
+private[spark] case class JobSubmitted(
+ finalRDD: RDD[_],
+ func: (TaskContext, Iterator[_]) => _,
+ partitions: Array[Int],
+ allowLocal: Boolean,
+ callSite: String,
+ listener: JobListener,
+ properties: Properties = null)
+ extends DAGSchedulerEvent
+
+private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[spark] case class CompletionEvent(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
+ extends DAGSchedulerEvent
+
+private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+
+private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+
+private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+
+private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
new file mode 100644
index 0000000000..ce0dc9093d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -0,0 +1,30 @@
+package org.apache.spark.scheduler
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.spark.metrics.source.Source
+
+private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source {
+ val metricRegistry = new MetricRegistry()
+ val sourceName = "DAGScheduler"
+
+ metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] {
+ override def getValue: Int = dagScheduler.failed.size
+ })
+
+ metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] {
+ override def getValue: Int = dagScheduler.running.size
+ })
+
+ metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] {
+ override def getValue: Int = dagScheduler.waiting.size
+ })
+
+ metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] {
+ override def getValue: Int = dagScheduler.nextJobId.get()
+ })
+
+ metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] {
+ override def getValue: Int = dagScheduler.activeJobs.size
+ })
+}
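A hedged sketch of reading one of the gauges registered above, assuming a DAGScheduler instance named dagScheduler; MetricRegistry.name("stage", "runningStages", "number") produces the key "stage.runningStages.number":

    // Sketch only: poll the running-stages gauge directly from the registry.
    val source = new DAGSchedulerSource(dagScheduler)
    val gauge = source.metricRegistry.getGauges().get("stage.runningStages.number")
    println("running stages: " + gauge.getValue)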
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
new file mode 100644
index 0000000000..370ccd183c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.{Logging, SparkEnv}
+import scala.collection.immutable.Set
+import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.conf.Configuration
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+
+
+/**
+ * Parses and holds information about inputFormat (and files) specified as a parameter.
+ */
+class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
+ val path: String) extends Logging {
+
+ var mapreduceInputFormat: Boolean = false
+ var mapredInputFormat: Boolean = false
+
+ validate()
+
+ override def toString(): String = {
+ "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+ hashCode
+ }
+
+  // Since we do not canonicalize paths, this can be wrong (e.g. relative vs. absolute paths).
+  // That is acceptable: this is only a best-effort attempt to remove duplicates.
+ override def equals(other: Any): Boolean = other match {
+ case that: InputFormatInfo => {
+      // The configuration is deliberately not compared.
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path
+ }
+ case _ => false
+ }
+
+ private def validate() {
+ logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
+
+ try {
+ if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapreduce package")
+ mapreduceInputFormat = true
+ }
+ else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
+ logDebug("inputformat is from mapred package")
+ mapredInputFormat = true
+ }
+ else {
+        throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
+          " is not a supported input format: it implements neither of the supported Hadoop APIs")
+ }
+ }
+ catch {
+ case e: ClassNotFoundException => {
+        throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found", e)
+ }
+ }
+ }
+
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
+ val env = SparkEnv.get
+ val conf = new JobConf(configuration)
+ env.hadoop.addCredentials(conf)
+ FileInputFormat.setInputPaths(conf, path)
+
+ val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
+ org.apache.hadoop.mapreduce.InputFormat[_, _]]
+ val job = new Job(conf)
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ val list = instance.getSplits(job)
+ for (split <- list) {
+ retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
+ }
+
+ return retval.toSet
+ }
+
+ // This method does not expect failures, since validate has already passed ...
+ private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
+ val env = SparkEnv.get
+ val jobConf = new JobConf(configuration)
+ env.hadoop.addCredentials(jobConf)
+ FileInputFormat.setInputPaths(jobConf, path)
+
+ val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
+ ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
+ org.apache.hadoop.mapred.InputFormat[_, _]]
+
+ val retval = new ArrayBuffer[SplitInfo]()
+ instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
+ elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
+ )
+
+ return retval.toSet
+ }
+
+ private def findPreferredLocations(): Set[SplitInfo] = {
+ logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
+ ", inputFormatClazz : " + inputFormatClazz)
+ if (mapreduceInputFormat) {
+ return prefLocsFromMapreduceInputFormat()
+ }
+ else {
+ assert(mapredInputFormat)
+ return prefLocsFromMapredInputFormat()
+ }
+ }
+}
+
+
+
+
+object InputFormatInfo {
+  /**
+    Computes the preferred locations based on the input(s) and returns a host-to-splits map.
+    Typical use of this method for container allocation (which is what the YARN branch
+    currently does) follows an algorithm like this:
+   a) For each host, count the number of splits hosted on that host.
+   b) Subtract the containers currently allocated on that host.
+   c) Compute rack information for each host and update a rack -> count map based on (b).
+   d) Allocate nodes based on (c).
+   e) On the allocation result, ensure that we do not allocate "too many" containers on a single
+      node (even if data locality there is very high): this limits the fragility of the job if a
+      single host (or a small set of hosts) goes down.
+
+   Repeat from (a) until the required number of nodes has been allocated.
+
+   If a node dies, follow the same procedure.
+  */
+ def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {
+
+ val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
+ for (inputSplit <- formats) {
+ val splits = inputSplit.findPreferredLocations()
+
+ for (split <- splits){
+ val location = split.hostLocation
+ val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
+ set += split
+ }
+ }
+
+ nodeToSplit
+ }
+}
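
A minimal sketch of steps (a) and (b) of the allocation outline in the doc comment above, built only on the host-to-splits map returned by computePreferredLocations; the currentAllocation map and the example object are hypothetical inputs, not part of this patch.

import scala.collection.mutable.{HashMap, HashSet}

import org.apache.spark.scheduler.SplitInfo

object PreferredLocationExample {
  // (a) Count the splits hosted on each host, then (b) subtract the containers
  //     already allocated there to estimate how many more are still desirable.
  def remainingDesired(
      nodeToSplits: HashMap[String, HashSet[SplitInfo]],
      currentAllocation: Map[String, Int]): Map[String, Int] = {
    nodeToSplits.map { case (host, splits) =>
      host -> math.max(splits.size - currentAllocation.getOrElse(host, 0), 0)
    }.toMap
  }
}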
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
new file mode 100644
index 0000000000..50c2b9acd6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * Interface used to listen for job completion or failure events after submitting a job to the
+ * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole
+ * job fails (and no further taskSucceeded events will happen).
+ */
+private[spark] trait JobListener {
+ def taskSucceeded(index: Int, result: Any)
+ def jobFailed(exception: Exception)
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
new file mode 100644
index 0000000000..98ef4d1e63
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.PrintWriter
+import java.io.File
+import java.io.FileNotFoundException
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{Map, HashMap, ListBuffer}
+import scala.io.Source
+
+import org.apache.spark._
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.TaskInfo
+
+// Used to record runtime information for each job, including the RDD graph,
+// tasks' start/stop and shuffle information, and information supplied from outside.
+
+class JobLogger(val logDirName: String) extends SparkListener with Logging {
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark"
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+ def this() = this(String.valueOf(System.currentTimeMillis()))
+
+ def getLogDir = logDir
+ def getJobIDtoPrintWriter = jobIDToPrintWriter
+ def getStageIDToJobID = stageIDToJobID
+ def getJobIDToStages = jobIDToStages
+ def getEventQueue = eventQueue
+
+  // Create a folder for log files; the folder's name is the creation time of the JobLogger.
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+    if (!dir.mkdirs()) {
+      logError("Failed to create log directory: " + logDir + "/" + logDirName + "/")
+ }
+ }
+
+  // Create a log file for one job; the file name is the job ID.
+ protected def createLogWriter(jobID: Int) {
+    try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+  // Close the log file and clean up the job's stage mappings in stageIDToJobID.
+ protected def closeLogWriter(jobID: Int) =
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+
+  // Write log information to the log file; the withTime parameter controls whether to record
+  // a time stamp with the information.
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+      writeInfo = DATE_FORMAT.format(date) + ": " + info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ case _ => rddList ++= getRddsInStage(dep.rdd)
+ }
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+  // Generate an indentation string of the given depth.
+ protected def indentString(indent: Int) = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ protected def getRddName(rdd: RDD[_]) = {
+ var rddName = rdd.getClass.getName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach{ dep => dep match {
+ case shufDep: ShuffleDependency[_,_] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+ }
+
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
+ var stageInfo: String = ""
+ if (stage.isShuffleMap) {
+ stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
+ stage.shuffleDep.get.shuffleId
+    } else {
+ stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
+ } else
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+
+ // Record task metrics into job log files
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics =
+ taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics =
+ taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(
+ stageSubmitted.stage.id,
+ "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.id, stageSubmitted.taskSize))
+ }
+
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(
+ stageCompleted.stageInfo.stage.id,
+ "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
+
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+    if (properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage)
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
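
For orientation, a hedged sketch of attaching this logger and the kind of lines its handlers write. Registering through SparkContext.addSparkListener is an assumption about the surrounding API (the exact hook may differ), and the sample lines are only indicative of the format built by the methods above (each is timestamped by default).

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.JobLogger

object JobLoggerExample {
  // Hedged sketch: assumes SparkContext exposes addSparkListener for registration.
  def attach(sc: SparkContext) {
    // Log files land under $SPARK_LOG_DIR (or /tmp/spark) in a folder named after the
    // logger's creation time, one file per job ID.
    val jobLogger = new JobLogger()
    sc.addSparkListener(jobLogger)
    // Indicative lines written by the handlers above:
    //   JOB_ID=0 STATUS=STARTED
    //   STAGE_ID=0 STATUS=SUBMITTED TASK_SIZE=4
    //   TASK_TYPE=RESULT_TASK STATUS=SUCCESS TID=3 STAGE_ID=0 ...
    //   STAGE_ID=0 STATUS=COMPLETED
  }
}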
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
new file mode 100644
index 0000000000..c381348a8d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * A result of a job in the DAGScheduler.
+ */
+private[spark] sealed trait JobResult
+
+private[spark] case object JobSucceeded extends JobResult
+private[spark] case class JobFailed(exception: Exception, failedStage: Option[Stage]) extends JobResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
new file mode 100644
index 0000000000..200d881799
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
+ * results to the given handler function.
+ */
+private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+ extends JobListener {
+
+ private var finishedTasks = 0
+
+ private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
+ private var jobResult: JobResult = null // If the job is finished, this will be its result
+
+ override def taskSucceeded(index: Int, result: Any) {
+ synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ resultHandler(index, result.asInstanceOf[T])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ jobFinished = true
+ jobResult = JobSucceeded
+ this.notifyAll()
+ }
+ }
+ }
+
+ override def jobFailed(exception: Exception) {
+ synchronized {
+ if (jobFinished) {
+ throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
+ }
+ jobFinished = true
+ jobResult = JobFailed(exception, None)
+ this.notifyAll()
+ }
+ }
+
+ def awaitResult(): JobResult = synchronized {
+ while (!jobFinished) {
+ this.wait()
+ }
+ return jobResult
+ }
+}
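
A self-contained sketch of how a caller could use JobWaiter: per-task results are collected through the handler, the waiter is handed to whatever submits the job as its JobListener, and the caller blocks on awaitResult. Here a plain thread stands in for the scheduler calling taskSucceeded; the example object and values are illustrative.

package org.apache.spark.scheduler

object JobWaiterExample {
  def main(args: Array[String]) {
    val totalTasks = 3
    val results = new Array[Int](totalTasks)
    val waiter = new JobWaiter[Int](totalTasks, (index, value) => results(index) = value)

    // In real use the scheduler invokes taskSucceeded as tasks finish;
    // here we simulate that from another thread.
    new Thread() {
      override def run() {
        for (i <- 0 until totalTasks) {
          waiter.taskSucceeded(i, i * 10)
        }
      }
    }.start()

    waiter.awaitResult() match {
      case JobSucceeded => println("results: " + results.mkString(", "))
      case JobFailed(exception, _) => println("job failed: " + exception)
    }
  }
}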
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
new file mode 100644
index 0000000000..1c61687f28
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.storage.BlockManagerId
+import java.io.{ObjectOutput, ObjectInput, Externalizable}
+
+/**
+ * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
+ * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * The map output sizes are compressed using MapOutputTracker.compressSize.
+ */
+private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
+ extends Externalizable {
+
+ def this() = this(null, null) // For deserialization only
+
+ def writeExternal(out: ObjectOutput) {
+ location.writeExternal(out)
+ out.writeInt(compressedSizes.length)
+ out.write(compressedSizes)
+ }
+
+ def readExternal(in: ObjectInput) {
+ location = BlockManagerId(in)
+ compressedSizes = new Array[Byte](in.readInt())
+ in.readFully(compressedSizes)
+ }
+}
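
The compressedSizes array stores one byte per reducer, so exact sizes cannot be preserved; MapOutputTracker.compressSize (not part of this hunk) maps each size onto that byte range. The sketch below shows one plausible logarithmic one-byte encoding, purely to illustrate why a single byte per reducer is enough; the chosen base and the object name are assumptions, not the actual implementation.

object SizeCompressionSketch {
  private val LOG_BASE = 1.1  // illustrative base, not necessarily what Spark uses

  // Encode a block size into a single byte: small relative error, huge range.
  def compress(size: Long): Byte = {
    if (size == 0) {
      0
    } else {
      math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
    }
  }

  // Decode back to an approximate size.
  def decompress(compressed: Byte): Long = {
    if (compressed == 0) 0L else math.pow(LOG_BASE, compressed & 0xFF).toLong
  }
}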
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
new file mode 100644
index 0000000000..2f157ccdd2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark._
+import java.io._
+import util.{MetadataCleaner, TimeStampedHashMap}
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+private[spark] object ResultTask {
+
+  // A simple map from the stage id to the serialized byte array of a task.
+  // Serves as a cache for task serialization, because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ return (rdd, func)
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
+
+private[spark] class ResultTask[T, U](
+ stageId: Int,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ var partition: Int,
+ @transient locs: Seq[TaskLocation],
+ val outputId: Int)
+ extends Task[U](stageId) with Externalizable {
+
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) {
+ null
+ } else {
+ rdd.partitions(partition)
+ }
+
+ @transient private val preferredLocs: Seq[TaskLocation] = {
+ if (locs == null) Nil else locs.toSet.toSeq
+ }
+
+ override def run(attemptId: Long): U = {
+ val context = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(context.taskMetrics)
+ try {
+ func(context, rdd.iterator(split, context))
+ } finally {
+ context.executeOnCompleteCallbacks()
+ }
+ }
+
+ override def preferredLocations: Seq[TaskLocation] = preferredLocs
+
+ override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.partitions(partition)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeInt(outputId)
+ out.writeLong(epoch)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partition = in.readInt()
+ val outputId = in.readInt()
+ epoch = in.readLong()
+ split = in.readObject().asInstanceOf[Partition]
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
new file mode 100644
index 0000000000..ca716b44e8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark._
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage._
+import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner}
+
+
+private[spark] object ShuffleMapTask {
+
+  // A simple map from the stage id to the serialized byte array of a task.
+  // Serves as a cache for task serialization, because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+ val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
+
+ val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues)
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(dep)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ return bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
+ synchronized {
+ val loader = Thread.currentThread.getContextClassLoader
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+ return (rdd, dep)
+ }
+ }
+
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def deserializeFileSet(bytes: Array[Byte]) : HashMap[String, Long] = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in)
+ val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+ return (HashMap(set.toSeq: _*))
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
+private[spark] class ShuffleMapTask(
+ stageId: Int,
+ var rdd: RDD[_],
+ var dep: ShuffleDependency[_,_],
+ var partition: Int,
+ @transient private var locs: Seq[TaskLocation])
+ extends Task[MapStatus](stageId)
+ with Externalizable
+ with Logging {
+
+ protected def this() = this(0, null, null, 0, null)
+
+ @transient private val preferredLocs: Seq[TaskLocation] = {
+ if (locs == null) Nil else locs.toSet.toSeq
+ }
+
+ var split = if (rdd == null) null else rdd.partitions(partition)
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.partitions(partition)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partition)
+ out.writeLong(epoch)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_
+ dep = dep_
+ partition = in.readInt()
+ epoch = in.readLong()
+ split = in.readObject().asInstanceOf[Partition]
+ }
+
+ override def run(attemptId: Long): MapStatus = {
+ val numOutputSplits = dep.partitioner.numPartitions
+
+ val taskContext = new TaskContext(stageId, partition, attemptId)
+ metrics = Some(taskContext.taskMetrics)
+
+ val blockManager = SparkEnv.get.blockManager
+ var shuffle: ShuffleBlocks = null
+ var buckets: ShuffleWriterGroup = null
+
+ try {
+ // Obtain all the block writers for shuffle blocks.
+ val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
+ shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
+ buckets = shuffle.acquireWriters(partition)
+
+ // Write the map output to its associated buckets.
+ for (elem <- rdd.iterator(split, taskContext)) {
+ val pair = elem.asInstanceOf[Product2[Any, Any]]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ buckets.writers(bucketId).write(pair)
+ }
+
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.size()
+ totalBytes += size
+ MapOutputTracker.compressSize(size)
+ }
+
+ // Update shuffle metrics.
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
+
+ return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ } catch { case e: Exception =>
+ // If there is an exception from running the task, revert the partial writes
+ // and throw the exception upstream to Spark.
+ if (buckets != null) {
+ buckets.writers.foreach(_.revertPartialWrites())
+ }
+ throw e
+ } finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && buckets != null) {
+ shuffle.releaseWriters(buckets)
+ }
+ // Execute the callbacks on task completion.
+ taskContext.executeOnCompleteCallbacks()
+ }
+ }
+
+ override def preferredLocations: Seq[TaskLocation] = preferredLocs
+
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
new file mode 100644
index 0000000000..3504424fa9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+import org.apache.spark.scheduler.cluster.TaskInfo
+import org.apache.spark.util.Distribution
+import org.apache.spark.{Logging, SparkContext, TaskEndReason, Utils}
+import org.apache.spark.executor.TaskMetrics
+
+sealed trait SparkListenerEvents
+
+case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+ extends SparkListenerEvents
+
+case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+
+case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
+case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) extends SparkListenerEvents
+
+case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+ extends SparkListenerEvents
+
+case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
+ extends SparkListenerEvents
+
+trait SparkListener {
+ /**
+ * Called when a stage is completed, with information on the completed stage
+ */
+ def onStageCompleted(stageCompleted: StageCompleted) { }
+
+ /**
+ * Called when a stage is submitted
+ */
+ def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
+
+ /**
+ * Called when a task starts
+ */
+  def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * Called when a task ends
+ */
+ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
+
+ /**
+ * Called when a job starts
+ */
+ def onJobStart(jobStart: SparkListenerJobStart) { }
+
+ /**
+ * Called when a job ends
+ */
+ def onJobEnd(jobEnd: SparkListenerJobEnd) { }
+
+}
+
+/**
+ * Simple SparkListener that logs a few summary statistics when each stage completes
+ */
+class StatsReportListener extends SparkListener with Logging {
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ import org.apache.spark.scheduler.StatsReportListener._
+ implicit val sc = stageCompleted
+ this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
+
+ //shuffle write
+ showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
+
+ //fetch & io
+ showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.fetchWaitTime})
+ showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
+ showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
+
+ //runtime breakdown
+
+ val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ case (info, metrics) => RuntimePercentage(info.duration, metrics)
+ }
+ showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
+ showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%")
+ showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%")
+ }
+
+}
+
+object StatsReportListener extends Logging {
+
+  // For profiling, the extremes are more interesting.
+ val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
+ val probabilities = percentiles.map{_ / 100.0}
+ val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+ def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
+ Distribution(stage.stageInfo.taskInfos.flatMap{
+ case ((info,metric)) => getMetric(info, metric)})
+ }
+
+  // Is there some way to set up the types so that this overload can be removed entirely?
+ def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = {
+ extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
+ }
+
+ def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+ val stats = d.statCounter
+ logInfo(heading + stats)
+ val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(percentilesHeader)
+ logInfo("\t" + quantiles.mkString("\t"))
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
+ dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+ }
+
+ def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
+ def f(d:Double) = format.format(d)
+ showDistribution(heading, dOpt, f _)
+ }
+
+ def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double])
+ (implicit stage: StageCompleted) {
+ showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
+ }
+
+ def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
+ (implicit stage: StageCompleted) {
+ showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
+ }
+
+ def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+ dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+ }
+
+ def showBytesDistribution(heading: String, dist: Distribution) {
+ showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+ showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+ }
+
+ def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
+ (implicit stage: StageCompleted) {
+ showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
+ }
+
+
+
+ val seconds = 1000L
+ val minutes = seconds * 60
+ val hours = minutes * 60
+
+ /**
+ * reformat a time interval in milliseconds to a prettier format for output
+ */
+ def millisToString(ms: Long) = {
+ val (size, units) =
+ if (ms > hours) {
+ (ms.toDouble / hours, "hours")
+ } else if (ms > minutes) {
+ (ms.toDouble / minutes, "min")
+ } else if (ms > seconds) {
+ (ms.toDouble / seconds, "s")
+ } else {
+ (ms.toDouble, "ms")
+ }
+ "%.1f %s".format(size, units)
+ }
+}
+
+
+
+case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+object RuntimePercentage {
+ def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+ val denom = totalTime.toDouble
+ val fetchTime = metrics.shuffleReadMetrics.map{_.fetchWaitTime}
+ val fetch = fetchTime.map{_ / denom}
+ val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom
+ val other = 1.0 - (exec + fetch.getOrElse(0d))
+ RuntimePercentage(exec, fetch, other)
+ }
+}
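
The distribution helpers in StatsReportListener are usable outside the listener as well. A small sketch, with made-up duration values, that formats an ad-hoc sample the same way the stage summaries above are formatted; output goes through logInfo, and the example object is illustrative.

import org.apache.spark.scheduler.StatsReportListener
import org.apache.spark.util.Distribution

object StatsHelpersExample {
  def main(args: Array[String]) {
    val durationsMs = Seq(120.0, 450.0, 980.0, 15000.0, 72000.0)
    // Logs the stat counter plus the percentile row for this sample, with each
    // value rendered through millisToString (e.g. "1.0 min").
    StatsReportListener.showDistribution(
      "ad-hoc task runtime:",
      Distribution(durationsMs),
      (d: Double) => StatsReportListener.millisToString(d.toLong))
  }
}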
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
new file mode 100644
index 0000000000..a65e1ecd6d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+
+import org.apache.spark.Logging
+
+/** Asynchronously passes SparkListenerEvents to registered SparkListeners. */
+private[spark] class SparkListenerBus() extends Logging {
+ private val sparkListeners = new ArrayBuffer[SparkListener]() with SynchronizedBuffer[SparkListener]
+
+ /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
+ * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+ private val EVENT_QUEUE_CAPACITY = 10000
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents](EVENT_QUEUE_CAPACITY)
+ private var queueFullErrorMessageLogged = false
+
+ new Thread("SparkListenerBus") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ event match {
+ case stageSubmitted: SparkListenerStageSubmitted =>
+ sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
+ case stageCompleted: StageCompleted =>
+ sparkListeners.foreach(_.onStageCompleted(stageCompleted))
+ case jobStart: SparkListenerJobStart =>
+ sparkListeners.foreach(_.onJobStart(jobStart))
+ case jobEnd: SparkListenerJobEnd =>
+ sparkListeners.foreach(_.onJobEnd(jobEnd))
+ case taskStart: SparkListenerTaskStart =>
+ sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskEnd: SparkListenerTaskEnd =>
+ sparkListeners.foreach(_.onTaskEnd(taskEnd))
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ def addListener(listener: SparkListener) {
+ sparkListeners += listener
+ }
+
+ def post(event: SparkListenerEvents) {
+ val eventAdded = eventQueue.offer(event)
+ if (!eventAdded && !queueFullErrorMessageLogged) {
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+ "rate at which tasks are being started by the scheduler.")
+ queueFullErrorMessageLogged = true
+ }
+ }
+}
+
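A hedged sketch of pairing the bus with a custom listener: TaskCounter is a hypothetical listener that only counts completed tasks, and wireUp shows the registration and posting contract described above; neither is part of this patch.

package org.apache.spark.scheduler

import java.util.concurrent.atomic.AtomicInteger

// Hypothetical listener: counts finished tasks as events are drained by the bus thread.
class TaskCounter extends SparkListener {
  val finished = new AtomicInteger(0)
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    finished.incrementAndGet()
  }
}

object SparkListenerBusExample {
  def wireUp(): (SparkListenerBus, TaskCounter) = {
    val bus = new SparkListenerBus()
    val counter = new TaskCounter
    bus.addListener(counter)
    // Producers call bus.post(event); the daemon thread delivers it asynchronously,
    // dropping events (with one logged error) once the 10000-slot queue is full.
    (bus, counter)
  }
}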
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
new file mode 100644
index 0000000000..5b40a3eb29
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import collection.mutable.ArrayBuffer
+
+// Information about a specific split instance: handles both mapred and mapreduce splits,
+// so that we do not need to worry about the differences between them.
+class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
+ val length: Long, val underlyingSplit: Any) {
+ override def toString(): String = {
+ "SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
+ ", hostLocation : " + hostLocation + ", path : " + path +
+ ", length : " + length + ", underlyingSplit " + underlyingSplit
+ }
+
+ override def hashCode(): Int = {
+ var hashCode = inputFormatClazz.hashCode
+ hashCode = hashCode * 31 + hostLocation.hashCode
+ hashCode = hashCode * 31 + path.hashCode
+    // Overflow is fine here; it is only a hash code.
+ hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
+ hashCode
+ }
+
+  // This is of limited use since most InputSplit implementations do not seem to implement equals.
+  // So unless the underlyingSplits are reference-equal, this will fail even if both point to the
+  // same block.
+ override def equals(other: Any): Boolean = other match {
+ case that: SplitInfo => {
+ this.hostLocation == that.hostLocation &&
+ this.inputFormatClazz == that.inputFormatClazz &&
+ this.path == that.path &&
+ this.length == that.length &&
+ // other split specific checks (like start for FileSplit)
+ this.underlyingSplit == that.underlyingSplit
+ }
+ case _ => false
+ }
+}
+
+object SplitInfo {
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapredSplit.getLength
+ for (host <- mapredSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
+ }
+ retval
+ }
+
+ def toSplitInfo(inputFormatClazz: Class[_], path: String,
+ mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
+ val retval = new ArrayBuffer[SplitInfo]()
+ val length = mapreduceSplit.getLength
+ for (host <- mapreduceSplit.getLocations) {
+ retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
+ }
+ retval
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
new file mode 100644
index 0000000000..87b1fe4e0c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.net.URI
+
+import org.apache.spark._
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * A stage is a set of independent tasks all computing the same function that need to run as part
+ * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
+ * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the
+ * DAGScheduler runs these stages in topological order.
+ *
+ * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for
+ * another stage, or a result stage, in which case its tasks directly compute the action that
+ * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes
+ * that each output partition is on.
+ *
+ * Each Stage also has a jobId, identifying the job that first submitted the stage. When FIFO
+ * scheduling is used, this allows Stages from earlier jobs to be computed first or recovered
+ * faster on failure.
+ */
+private[spark] class Stage(
+ val id: Int,
+ val rdd: RDD[_],
+ val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
+ val parents: List[Stage],
+ val jobId: Int,
+ callSite: Option[String])
+ extends Logging {
+
+ val isShuffleMap = shuffleDep != None
+ val numPartitions = rdd.partitions.size
+ val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
+ var numAvailableOutputs = 0
+
+ /** When first task was submitted to scheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+
+ private var nextAttemptId = 0
+
+ def isAvailable: Boolean = {
+ if (!isShuffleMap) {
+ true
+ } else {
+ numAvailableOutputs == numPartitions
+ }
+ }
+
+ def addOutputLoc(partition: Int, status: MapStatus) {
+ val prevList = outputLocs(partition)
+ outputLocs(partition) = status :: prevList
+ if (prevList == Nil)
+ numAvailableOutputs += 1
+ }
+
+ def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_.location == bmAddress)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ numAvailableOutputs -= 1
+ }
+ }
+
+ def removeOutputsOnExecutor(execId: String) {
+ var becameUnavailable = false
+ for (partition <- 0 until numPartitions) {
+ val prevList = outputLocs(partition)
+ val newList = prevList.filterNot(_.location.executorId == execId)
+ outputLocs(partition) = newList
+ if (prevList != Nil && newList == Nil) {
+ becameUnavailable = true
+ numAvailableOutputs -= 1
+ }
+ }
+ if (becameUnavailable) {
+ logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
+ this, execId, numAvailableOutputs, numPartitions, isAvailable))
+ }
+ }
+
+ def newAttemptId(): Int = {
+ val id = nextAttemptId
+ nextAttemptId += 1
+ return id
+ }
+
+ val name = callSite.getOrElse(rdd.origin)
+
+ override def toString = "Stage " + id
+
+ override def hashCode(): Int = id
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
new file mode 100644
index 0000000000..72cb1c9ce8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.scheduler.cluster.TaskInfo
+import scala.collection._
+import org.apache.spark.executor.TaskMetrics
+
+case class StageInfo(
+ val stage: Stage,
+ val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+) {
+ override def toString = stage.rdd.toString
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
new file mode 100644
index 0000000000..598d91752a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.serializer.SerializerInstance
+import java.io.{DataInputStream, DataOutputStream}
+import java.nio.ByteBuffer
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import org.apache.spark.util.ByteBufferInputStream
+import scala.collection.mutable.HashMap
+import org.apache.spark.executor.TaskMetrics
+
+/**
+ * A task to execute on a worker node.
+ */
+private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
+ def run(attemptId: Long): T
+ def preferredLocations: Seq[TaskLocation] = Nil
+
+ var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+
+ var metrics: Option[TaskMetrics] = None
+
+}
+
+/**
+ * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We
+ * need to send the list of JARs and files added to the SparkContext with each task to ensure that
+ * worker nodes find out about it, but we can't make it part of the Task because the user's code in
+ * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by
+ * first writing out its dependencies.
+ */
+private[spark] object Task {
+ /**
+ * Serialize a task and the current app dependencies (files and JARs added to the SparkContext)
+ */
+ def serializeWithDependencies(
+ task: Task[_],
+ currentFiles: HashMap[String, Long],
+ currentJars: HashMap[String, Long],
+ serializer: SerializerInstance)
+ : ByteBuffer = {
+
+ val out = new FastByteArrayOutputStream(4096)
+ val dataOut = new DataOutputStream(out)
+
+ // Write currentFiles
+ dataOut.writeInt(currentFiles.size)
+ for ((name, timestamp) <- currentFiles) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
+ }
+
+ // Write currentJars
+ dataOut.writeInt(currentJars.size)
+ for ((name, timestamp) <- currentJars) {
+ dataOut.writeUTF(name)
+ dataOut.writeLong(timestamp)
+ }
+
+ // Write the task itself and finish
+ dataOut.flush()
+ val taskBytes = serializer.serialize(task).array()
+ out.write(taskBytes)
+ out.trim()
+ ByteBuffer.wrap(out.array)
+ }
+
+ /**
+ * Deserialize the list of dependencies in a task serialized with serializeWithDependencies,
+ * and return the task itself as a serialized ByteBuffer. The caller can then update its
+ * ClassLoaders and deserialize the task.
+ *
+ * @return (taskFiles, taskJars, taskBytes)
+ */
+ def deserializeWithDependencies(serializedTask: ByteBuffer)
+ : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+
+ val in = new ByteBufferInputStream(serializedTask)
+ val dataIn = new DataInputStream(in)
+
+ // Read task's files
+ val taskFiles = new HashMap[String, Long]()
+ val numFiles = dataIn.readInt()
+ for (i <- 0 until numFiles) {
+ taskFiles(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Read task's JARs
+ val taskJars = new HashMap[String, Long]()
+ val numJars = dataIn.readInt()
+ for (i <- 0 until numJars) {
+ taskJars(dataIn.readUTF()) = dataIn.readLong()
+ }
+
+ // Create a sub-buffer for the rest of the data, which is the serialized Task object
+ val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task
+ (taskFiles, taskJars, subBuffer)
+ }
+}
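
For reference, the framing agreed on by serializeWithDependencies and deserializeWithDependencies is simple enough to exercise with plain JDK streams. The sketch below is illustrative only: a dummy byte array stands in for the serializer output, and the read side uses available() instead of the ByteBuffer slicing done above.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import scala.collection.mutable.HashMap

    object TaskFramingSketch {
      def main(args: Array[String]) {
        val files = HashMap("hdfs://nn/app.conf" -> 1L)          // hypothetical file -> timestamp
        val jars  = HashMap("http://driver:4040/app.jar" -> 2L)  // hypothetical jar -> timestamp
        val taskBytes = Array[Byte](1, 2, 3)                     // stands in for serializer.serialize(task)

        // Write side: <numFiles><name,ts>* <numJars><name,ts>* <serialized task bytes>
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        for (m <- Seq(files, jars)) {
          out.writeInt(m.size)
          for ((name, ts) <- m) { out.writeUTF(name); out.writeLong(ts) }
        }
        out.flush()
        bos.write(taskBytes)

        // Read side, mirroring deserializeWithDependencies
        val in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray))
        val readFiles = new HashMap[String, Long]
        for (i <- 0 until in.readInt()) { readFiles(in.readUTF()) = in.readLong() }
        val readJars = new HashMap[String, Long]
        for (i <- 0 until in.readInt()) { readJars(in.readUTF()) = in.readLong() }
        val rest = new Array[Byte](in.available())               // what remains is the serialized task
        in.read(rest)
        println((readFiles, readJars, rest.toSeq))
      }
    }
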
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
new file mode 100644
index 0000000000..67c9a6760b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+/**
+ * A location where a task should run. This can either be a host or a (host, executorID) pair.
+ * In the latter case, we will prefer to launch the task on that executorID, but our next level
+ * of preference will be executors on the same host if this is not possible.
+ */
+private[spark]
+class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
+ override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+}
+
+private[spark] object TaskLocation {
+ def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+
+ def apply(host: String) = new TaskLocation(host, None)
+}
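
A quick usage sketch of the two factory methods (hypothetical host and executor names; TaskLocation is private[spark], so this assumes Spark-internal scope):

    val execLevel = TaskLocation("host1", "executor-7")  // prefer this exact executor
    val hostLevel = TaskLocation("host1")                // any executor on host1 is fine

    println(execLevel)              // TaskLocation(host1, Some(executor-7))
    println(hostLevel.executorId)   // None
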
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
new file mode 100644
index 0000000000..776675d28c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io._
+
+import scala.collection.mutable.Map
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.{Utils, SparkEnv}
+import java.nio.ByteBuffer
+
+// Task result. Also contains updates to accumulator variables.
+// TODO: Use of distributed cache to return result is a hack to get around
+// what seems to be a bug with messages over 60KB in libprocess; fix it
+private[spark]
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+ extends Externalizable
+{
+ def this() = this(null.asInstanceOf[T], null, null)
+
+ override def writeExternal(out: ObjectOutput) {
+
+ val objectSer = SparkEnv.get.serializer.newInstance()
+ val bb = objectSer.serialize(value)
+
+ out.writeInt(bb.remaining())
+ Utils.writeByteBuffer(bb, out)
+
+ out.writeInt(accumUpdates.size)
+ for ((key, value) <- accumUpdates) {
+ out.writeLong(key)
+ out.writeObject(value)
+ }
+ out.writeObject(metrics)
+ }
+
+ override def readExternal(in: ObjectInput) {
+
+ val objectSer = SparkEnv.get.serializer.newInstance()
+
+ val blen = in.readInt()
+ val byteVal = new Array[Byte](blen)
+ in.readFully(byteVal)
+ value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+
+ val numUpdates = in.readInt
+ if (numUpdates == 0) {
+ accumUpdates = null
+ } else {
+ accumUpdates = Map()
+ for (i <- 0 until numUpdates) {
+ accumUpdates(in.readLong()) = in.readObject()
+ }
+ }
+ metrics = in.readObject().asInstanceOf[TaskMetrics]
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
new file mode 100644
index 0000000000..63be8ba3f5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.scheduler.cluster.Pool
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+/**
+ * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
+ * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
+ * and are responsible for sending the tasks to the cluster, running them, retrying if there
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler through
+ * the TaskSchedulerListener interface.
+ */
+private[spark] trait TaskScheduler {
+
+ def rootPool: Pool
+
+ def schedulingMode: SchedulingMode
+
+ def start(): Unit
+
+ // Invoked after the system has successfully initialized (typically in the SparkContext).
+ // YARN uses this to bootstrap allocation of resources based on preferred locations,
+ // wait for slave registrations, etc.
+ def postStartHook() { }
+
+ // Disconnect from the cluster.
+ def stop(): Unit
+
+ // Submit a sequence of tasks to run.
+ def submitTasks(taskSet: TaskSet): Unit
+
+ // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setListener(listener: TaskSchedulerListener): Unit
+
+ // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+ def defaultParallelism(): Int
+}
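
To make the contract concrete, here is a skeletal, do-nothing implementation. The class name is hypothetical, and it is declared inside the scheduler package only because the trait is private[spark]; a real scheduler also wires in a backend, as ClusterScheduler does.

    package org.apache.spark.scheduler   // the trait is private[spark], so the stub lives here too

    import org.apache.spark.scheduler.cluster.{Pool, SchedulingMode}
    import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode

    // A hypothetical no-op scheduler, only to make the lifecycle concrete:
    // setListener() -> start() -> submitTasks() ... -> stop().
    private[spark] class NoOpTaskScheduler extends TaskScheduler {
      private var listener: TaskSchedulerListener = null

      override val rootPool: Pool = new Pool("noop", SchedulingMode.FIFO, 0, 0)
      override def schedulingMode: SchedulingMode = SchedulingMode.FIFO

      override def setListener(l: TaskSchedulerListener) { listener = l }
      override def start() {}
      override def submitTasks(taskSet: TaskSet) {
        // A real implementation wraps the TaskSet in a TaskSetManager and offers it resources.
      }
      override def stop() {}
      override def defaultParallelism(): Int = 1
    }
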
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
new file mode 100644
index 0000000000..83be051c1a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.scheduler.cluster.TaskInfo
+import scala.collection.mutable.Map
+
+import org.apache.spark.TaskEndReason
+import org.apache.spark.executor.TaskMetrics
+
+/**
+ * Interface for getting events back from the TaskScheduler.
+ */
+private[spark] trait TaskSchedulerListener {
+ // A task has started.
+ def taskStarted(task: Task[_], taskInfo: TaskInfo)
+
+ // A task has finished or failed.
+ def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
+
+ // A node was added to the cluster.
+ def executorGained(execId: String, host: String): Unit
+
+ // A node was lost from the cluster.
+ def executorLost(execId: String): Unit
+
+ // The TaskScheduler wants to abort an entire task set.
+ def taskSetFailed(taskSet: TaskSet, reason: String): Unit
+}
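
In the real code the DAGScheduler is the implementation of this trait; the hypothetical logging stub below is only meant to show the callback shapes, and again assumes Spark-internal scope.

    package org.apache.spark.scheduler   // the trait is private[spark]

    import scala.collection.mutable.Map

    import org.apache.spark.TaskEndReason
    import org.apache.spark.executor.TaskMetrics
    import org.apache.spark.scheduler.cluster.TaskInfo

    // Hypothetical stub listener that just logs each upcall.
    private[spark] class LoggingTaskSchedulerListener extends TaskSchedulerListener {
      def taskStarted(task: Task[_], taskInfo: TaskInfo) {
        println("task " + taskInfo.index + " started on " + taskInfo.host)
      }
      def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
          taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
        println("task " + taskInfo.index + " ended: " + reason)
      }
      def executorGained(execId: String, host: String) { println("gained executor " + execId) }
      def executorLost(execId: String) { println("lost executor " + execId) }
      def taskSetFailed(taskSet: TaskSet, reason: String) { println(taskSet + " failed: " + reason) }
    }
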
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
new file mode 100644
index 0000000000..c3ad325156
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.util.Properties
+
+/**
+ * A set of tasks submitted together to the low-level TaskScheduler, usually representing
+ * missing partitions of a particular stage.
+ */
+private[spark] class TaskSet(
+ val tasks: Array[Task[_]],
+ val stageId: Int,
+ val attempt: Int,
+ val priority: Int,
+ val properties: Properties) {
+ val id: String = stageId + "." + attempt
+
+ override def toString: String = "TaskSet " + id
+}
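
The id follows a "stageId.attempt" convention; a tiny sketch (assuming Spark-internal scope, with an empty task array just to satisfy the constructor):

    import java.util.Properties

    val ts = new TaskSet(Array.empty[Task[_]], stageId = 3, attempt = 0, priority = 3,
      properties = new Properties)
    println(ts.id)         // 3.0
    println(ts.toString)   // TaskSet 3.0
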
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
new file mode 100644
index 0000000000..3196ab5022
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.lang.{Boolean => JBoolean}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
+
+/**
+ * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
+ * initialize() and start(), then submit task sets through the submitTasks method.
+ *
+ * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * It handles common logic, like determining a scheduling order across jobs, waking up to launch
+ * speculative tasks, etc.
+ *
+ * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
+ * threads, so it needs locks in public API methods to maintain its state. In addition, some
+ * SchedulerBackends synchronize on themselves when they want to send events here, and then
+ * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
+ * we are holding a lock on ourselves.
+ */
+private[spark] class ClusterScheduler(val sc: SparkContext)
+ extends TaskScheduler
+ with Logging
+{
+ // How often to check for speculative tasks
+ val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
+ // Threshold above which we warn the user that the initial TaskSet may be starved
+ val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+
+ val taskIdToTaskSetId = new HashMap[Long, String]
+ val taskIdToExecutorId = new HashMap[Long, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+ @volatile private var hasReceivedTask = false
+ @volatile private var hasLaunchedTask = false
+ private val starvationTimer = new Timer(true)
+
+ // Incrementing Mesos task IDs
+ val nextTaskId = new AtomicLong(0)
+
+ // Which executor IDs we have executors on
+ val activeExecutorIds = new HashSet[String]
+
+ // The set of executors we have on each host; this is used to compute hostsAlive, which
+ // in turn is used to decide when we can attain data locality on a given host
+ private val executorsByHost = new HashMap[String, HashSet[String]]
+
+ private val executorIdToHost = new HashMap[String, String]
+
+ // JAR server, if any JARs were added by the user to the SparkContext
+ var jarServer: HttpServer = null
+
+ // URIs of JARs to pass to executor
+ var jarUris: String = ""
+
+ // Listener object to pass upcalls into
+ var listener: TaskSchedulerListener = null
+
+ var backend: SchedulerBackend = null
+
+ val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+ // default scheduler is FIFO
+ val schedulingMode: SchedulingMode = SchedulingMode.withName(
+ System.getProperty("spark.cluster.schedulingmode", "FIFO"))
+
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
+ }
+
+ def initialize(context: SchedulerBackend) {
+ backend = context
+ // temporarily set rootPool name to empty
+ rootPool = new Pool("", schedulingMode, 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case SchedulingMode.FIFO =>
+ new FIFOSchedulableBuilder(rootPool)
+ case SchedulingMode.FAIR =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
+ }
+
+ def newTaskId(): Long = nextTaskId.getAndIncrement()
+
+ override def start() {
+ backend.start()
+
+ if (System.getProperty("spark.speculation", "false").toBoolean) {
+ new Thread("ClusterScheduler speculation check") {
+ setDaemon(true)
+
+ override def run() {
+ logInfo("Starting speculative execution thread")
+ while (true) {
+ try {
+ Thread.sleep(SPECULATION_INTERVAL)
+ } catch {
+ case e: InterruptedException => {}
+ }
+ checkSpeculatableTasks()
+ }
+ }
+ }.start()
+ }
+ }
+
+ override def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
+ logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
+ this.synchronized {
+ val manager = new ClusterTaskSetManager(this, taskSet)
+ activeTaskSets(taskSet.id) = manager
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+
+ if (!hasReceivedTask) {
+ starvationTimer.scheduleAtFixedRate(new TimerTask() {
+ override def run() {
+ if (!hasLaunchedTask) {
+ logWarning("Initial job has not accepted any resources; " +
+ "check your cluster UI to ensure that workers are registered " +
+ "and have sufficient memory")
+ } else {
+ this.cancel()
+ }
+ }
+ }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
+ }
+ hasReceivedTask = true
+ }
+ backend.reviveOffers()
+ }
+
+ def taskSetFinished(manager: TaskSetManager) {
+ this.synchronized {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
+ }
+ }
+
+ /**
+ * Called by cluster manager to offer resources on slaves. We respond by asking our active task
+ * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
+ * that tasks are balanced across the cluster.
+ */
+ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
+ SparkEnv.set(sc.env)
+
+ // Mark each slave as alive and remember its hostname
+ for (o <- offers) {
+ executorIdToHost(o.executorId) = o.host
+ if (!executorsByHost.contains(o.host)) {
+ executorsByHost(o.host) = new HashSet[String]()
+ executorGained(o.executorId, o.host)
+ }
+ }
+
+ // Build a list of tasks to assign to each worker
+ val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ val availableCpus = offers.map(o => o.cores).toArray
+ val sortedTaskSets = rootPool.getSortedTaskSetQueue()
+ for (taskSet <- sortedTaskSets) {
+ logDebug("parentName: %s, name: %s, runningTasks: %s".format(
+ taskSet.parent.name, taskSet.name, taskSet.runningTasks))
+ }
+
+ // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
+ // of locality levels so that it gets a chance to launch local tasks on all of them.
+ var launchedTask = false
+ for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
+ do {
+ launchedTask = false
+ for (i <- 0 until offers.size) {
+ val execId = offers(i).executorId
+ val host = offers(i).host
+ for (task <- taskSet.resourceOffer(execId, host, availableCpus(i), maxLocality)) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskSetTaskIds(taskSet.taskSet.id) += tid
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= 1
+ launchedTask = true
+ }
+ }
+ } while (launchedTask)
+ }
+
+ if (tasks.size > 0) {
+ hasLaunchedTask = true
+ }
+ return tasks
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var failedExecutor: Option[String] = None
+ var taskFailed = false
+ synchronized {
+ try {
+ if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
+ // We lost this entire executor, so remember that it's gone
+ val execId = taskIdToExecutorId(tid)
+ if (activeExecutorIds.contains(execId)) {
+ removeExecutor(execId)
+ failedExecutor = Some(execId)
+ }
+ }
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
+ }
+ if (TaskState.isFinished(state)) {
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
+ }
+ taskIdToExecutorId.remove(tid)
+ }
+ if (state == TaskState.FAILED) {
+ taskFailed = true
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ }
+ } catch {
+ case e: Exception => logError("Exception in statusUpdate", e)
+ }
+ }
+ // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
+ }
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
+ backend.reviveOffers()
+ }
+ if (taskFailed) {
+ // Also revive offers if a task had failed for some reason other than host lost
+ backend.reviveOffers()
+ }
+ }
+
+ def error(message: String) {
+ synchronized {
+ if (activeTaskSets.size > 0) {
+ // Have each task set throw a SparkException with the error
+ for ((taskSetId, manager) <- activeTaskSets) {
+ try {
+ manager.error(message)
+ } catch {
+ case e: Exception => logError("Exception in error callback", e)
+ }
+ }
+ } else {
+ // No task sets are active but we still got an error. Just exit since this
+ // must mean the error is during registration.
+ // It might be good to do something smarter here in the future.
+ logError("Exiting due to error from cluster scheduler: " + message)
+ System.exit(1)
+ }
+ }
+ }
+
+ override def stop() {
+ if (backend != null) {
+ backend.stop()
+ }
+ if (jarServer != null) {
+ jarServer.stop()
+ }
+
+ // Sleep for an arbitrary 5 seconds to ensure that messages are sent out.
+ // TODO: Do something better!
+ Thread.sleep(5000L)
+ }
+
+ override def defaultParallelism() = backend.defaultParallelism()
+
+
+ // Check for speculatable tasks in all our active jobs.
+ def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ shouldRevive = rootPool.checkSpeculatableTasks()
+ }
+ if (shouldRevive) {
+ backend.reviveOffers()
+ }
+ }
+
+ // Check for pending tasks in all our active jobs.
+ def hasPendingTasks: Boolean = {
+ synchronized {
+ rootPool.hasPendingTasks()
+ }
+ }
+
+ def executorLost(executorId: String, reason: ExecutorLossReason) {
+ var failedExecutor: Option[String] = None
+
+ synchronized {
+ if (activeExecutorIds.contains(executorId)) {
+ val hostPort = executorIdToHost(executorId)
+ logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
+ removeExecutor(executorId)
+ failedExecutor = Some(executorId)
+ } else {
+ // We may get multiple executorLost() calls with different loss reasons. For example, one
+ // may be triggered by a dropped connection from the slave while another may be a report
+ // of executor termination from Mesos. We produce log messages for both so we eventually
+ // report the termination reason.
+ logError("Lost an executor " + executorId + " (already removed): " + reason)
+ }
+ }
+ // Call listener.executorLost without holding the lock on this to prevent deadlock
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
+ backend.reviveOffers()
+ }
+ }
+
+ /** Remove an executor from all our data structures and mark it as lost */
+ private def removeExecutor(executorId: String) {
+ activeExecutorIds -= executorId
+ val host = executorIdToHost(executorId)
+ val execs = executorsByHost.getOrElse(host, new HashSet)
+ execs -= executorId
+ if (execs.isEmpty) {
+ executorsByHost -= host
+ }
+ executorIdToHost -= executorId
+ rootPool.executorLost(executorId, host)
+ }
+
+ def executorGained(execId: String, host: String) {
+ listener.executorGained(execId, host)
+ }
+
+ def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
+ executorsByHost.get(host).map(_.toSet)
+ }
+
+ def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
+ executorsByHost.contains(host)
+ }
+
+ def isExecutorAlive(execId: String): Boolean = synchronized {
+ activeExecutorIds.contains(execId)
+ }
+
+ // By default, rack is unknown
+ def getRackForHost(value: String): Option[String] = None
+}
+
+
+object ClusterScheduler {
+ /**
+ * Used to balance containers across hosts.
+ *
+ * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of
+ * resource offers representing the order in which the offers should be used. The resource
+ * offers are ordered such that we'll allocate one container on each host before allocating a
+ * second container on any host, and so on, in order to reduce the damage if a host fails.
+ *
+ * For example, given <h1, [o1, o2, o3]>, <h2, [o4]>, <h3, [o5, o6]>, returns
+ * [o1, o5, o4, o2, o6, o3]
+ */
+ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
+ val _keyList = new ArrayBuffer[K](map.size)
+ _keyList ++= map.keys
+
+ // order keyList based on population of value in map
+ val keyList = _keyList.sortWith(
+ (left, right) => map(left).size > map(right).size
+ )
+
+ val retval = new ArrayBuffer[T](keyList.size * 2)
+ var index = 0
+ var found = true
+
+ while (found) {
+ found = false
+ for (key <- keyList) {
+ val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
+ assert(containerList != null)
+ // Get the index'th entry for this host - if present
+ if (index < containerList.size){
+ retval += containerList.apply(index)
+ found = true
+ }
+ }
+ index += 1
+ }
+
+ retval.toList
+ }
+}
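
The prioritizeContainers helper can be exercised directly against the example in its scaladoc; a small sketch with hypothetical host and offer names:

    import scala.collection.mutable.{ArrayBuffer, HashMap}
    import org.apache.spark.scheduler.cluster.ClusterScheduler

    val offers = HashMap(
      "h1" -> ArrayBuffer("o1", "o2", "o3"),
      "h2" -> ArrayBuffer("o4"),
      "h3" -> ArrayBuffer("o5", "o6"))

    // Hosts with more offers come first, then offers are taken one per host per round.
    println(ClusterScheduler.prioritizeContainers(offers))
    // List(o1, o5, o4, o2, o6, o3)
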
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
new file mode 100644
index 0000000000..a33307b83a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -0,0 +1,712 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import java.util.{Arrays, NoSuchElementException}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils}
+import org.apache.spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler._
+import org.apache.spark.util.{SystemClock, Clock}
+
+
+/**
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
+ * the status of each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
+ */
+private[spark] class ClusterTaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet,
+ clock: Clock = SystemClock)
+ extends TaskSetManager
+ with Logging
+{
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+
+ // Maximum times a task is allowed to fail before failing the job
+ val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val env = SparkEnv.get
+ val ser = env.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksFinished = 0
+
+ var weight = 1
+ var minShare = 0
+ var runningTasks = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent: Schedulable = null
+
+ // Set of pending tasks for each executor. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+ // but at host level.
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each rack -- similar to the above.
+ private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set containing pending tasks with no locality preferences.
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // Set containing all pending tasks (also used as a stack, as above).
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the TaskSet fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
+ // Map of recent exceptions (identified by string representation and top stack frame) to
+ // duplicate count (how many times the same exception has appeared) and time the full exception
+ // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker epoch and set it on all tasks
+ val epoch = sched.mapOutputTracker.getEpoch
+ logDebug("Epoch for " + taskSet + ": " + epoch)
+ for (t <- tasks) {
+ t.epoch = epoch
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+ val myLocalityLevels = computeValidLocalityLevels()
+ val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+ // Delay scheduling variables: we keep track of our current locality level and the time we
+ // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+ // We then move down if we manage to launch a "more local" task.
+ var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
+ var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
+
+ /**
+ * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+ * re-adding the task so only include it in each list if it's not already there.
+ */
+ private def addPendingTask(index: Int, readding: Boolean = false) {
+ // Utility method that adds `index` to a list only if readding=false or it's not already there
+ def addTo(list: ArrayBuffer[Int]) {
+ if (!readding || !list.contains(index)) {
+ list += index
+ }
+ }
+
+ var hadAliveLocations = false
+ for (loc <- tasks(index).preferredLocations) {
+ for (execId <- loc.executorId) {
+ if (sched.isExecutorAlive(execId)) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ hadAliveLocations = true
+ }
+ }
+ if (sched.hasExecutorsAliveOnHost(loc.host)) {
+ addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ for (rack <- sched.getRackForHost(loc.host)) {
+ addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+ }
+ hadAliveLocations = true
+ }
+ }
+
+ if (!hadAliveLocations) {
+ // Even though the task might've had preferred locations, all of those hosts or executors
+ // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+ addTo(pendingTasksWithNoPrefs)
+ }
+
+ if (!readding) {
+ allPendingTasks += index // No point scanning this whole list to find the old task there
+ }
+ }
+
+ /**
+ * Return the pending tasks list for a given executor ID, or an empty list if
+ * there is no map entry for that executor
+ */
+ private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+ pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending tasks list for a given host, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending rack-local task list for a given rack, or an empty list if
+ * there is no map entry for that rack
+ */
+ private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+ pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+ }
+
+ /**
+ * Dequeue a pending task from the given list and return its index.
+ * Return None if the list is empty.
+ * This method also cleans up any tasks in the list that have already
+ * been launched, since we want that to happen lazily.
+ */
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !finished(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ /** Check whether a task is currently running an attempt on a given host */
+ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+ taskAttempts(taskIndex).exists(_.host == host)
+ }
+
+ /**
+ * Return a speculative task for a given executor if any are available. The task should not have
+ * an attempt running on this host, in case the host is slow. In addition, the task should meet
+ * the given locality constraint.
+ */
+ private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+
+ if (!speculatableTasks.isEmpty) {
+ // Check for process-local or preference-less tasks; note that tasks can be process-local
+ // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val prefs = tasks(index).preferredLocations
+ val executors = prefs.flatMap(_.executorId)
+ if (prefs.size == 0 || executors.contains(execId)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+ }
+
+ // Check for node-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val locations = tasks(index).preferredLocations.map(_.host)
+ if (locations.contains(host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+ }
+
+ // Check for rack-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for (rack <- sched.getRackForHost(host)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ if (racks.contains(rack)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+ }
+ }
+
+ // Check for non-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+ }
+
+ return None
+ }
+
+ /**
+ * Dequeue a pending task for a given node and return its index and locality level.
+ * Only search for tasks matching the given locality constraint.
+ */
+ private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for {
+ rack <- sched.getRackForHost(host)
+ index <- findTaskFromList(getPendingTasksForRack(rack))
+ } {
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+
+ // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- findTaskFromList(allPendingTasks)) {
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(execId, host, locality)
+ }
+
+ /**
+ * Respond to an offer of a single slave from the scheduler by finding a task
+ */
+ override def resourceOffer(
+ execId: String,
+ host: String,
+ availableCpus: Int,
+ maxLocality: TaskLocality.TaskLocality)
+ : Option[TaskDescription] =
+ {
+ if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ val curTime = clock.getTime()
+
+ var allowedLocality = getAllowedLocalityLevel(curTime)
+ if (allowedLocality > maxLocality) {
+ allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
+ }
+
+ findTask(execId, host, allowedLocality) match {
+ case Some((index, taskLocality)) => {
+ // Found a task; do some bookkeeping and return a task description
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, taskLocality))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // Update our locality level for delay scheduling
+ currentLocalityIndex = getLocalityIndex(taskLocality)
+ lastLaunchTime = curTime
+ // Serialize and return the task
+ val startTime = clock.getTime()
+ // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+ // we assume the task can be serialized without exceptions.
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = clock.getTime() - startTime
+ increaseRunningTasks(1)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ if (taskAttempts(index).size == 1)
+ taskStarted(task,info)
+ return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ /**
+ * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+ */
+ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+ while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+ currentLocalityIndex < myLocalityLevels.length - 1)
+ {
+ // Jump to the next locality level, and remove our waiting time for the current one since
+ // we don't want to count it again on the next one
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ }
+ myLocalityLevels(currentLocalityIndex)
+ }
+
+ /**
+ * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+ * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+ * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+ */
+ def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+ var index = 0
+ while (locality > myLocalityLevels(index)) {
+ index += 1
+ }
+ index
+ }
+
+ /** Called by cluster scheduler when one of our tasks changes state */
+ override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ SparkEnv.set(env)
+ state match {
+ case TaskState.FINISHED =>
+ taskFinished(tid, state, serializedData)
+ case TaskState.LOST =>
+ taskLost(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskLost(tid, state, serializedData)
+ case TaskState.KILLED =>
+ taskLost(tid, state, serializedData)
+ case _ =>
+ }
+ }
+
+ def taskStarted(task: Task[_], info: TaskInfo) {
+ sched.listener.taskStarted(task, info)
+ }
+
+ def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markSuccessful()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+ tid, info.duration, info.host, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ try {
+ val result = ser.deserialize[TaskResult[_]](serializedData)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+ } catch {
+ case cnf: ClassNotFoundException =>
+ val loader = Thread.currentThread().getContextClassLoader
+ throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
+ case ex => throw ex
+ }
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markFailed()
+ decreaseRunningTasks(1)
+ if (!finished(index)) {
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ if (serializedData != null && serializedData.limit() > 0) {
+ val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
+ reason match {
+ case fetchFailed: FetchFailed =>
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ finished(index) = true
+ tasksFinished += 1
+ sched.taskSetFinished(this)
+ decreaseRunningTasks(runningTasks)
+ return
+
+ case taskResultTooBig: TaskResultTooBigFailure =>
+ logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format(
+ tid))
+ abort("Task %s result exceeded Akka frame size".format(tid))
+ return
+
+ case ef: ExceptionFailure =>
+ sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case _ => {}
+ }
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (state == TaskState.FAILED || state == TaskState.LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ override def error(message: String) {
+ // Save the error message
+ abort("Error: " + message)
+ }
+
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.listener.taskSetFailed(taskSet, message)
+ decreaseRunningTasks(runningTasks)
+ sched.taskSetFinished(this)
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {}
+
+ override def removeSchedulable(schedulable: Schedulable) {}
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ val sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
+ return sortedTaskSetQueue
+ }
+
+ /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
+ override def executorLost(execId: String, host: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+ // task that used to have locations on only this host might now go to the no-prefs list. Note
+ // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+ // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) {
+ addPendingTask(index, readding=true)
+ }
+ for (index <- getPendingTasksForHost(host)) {
+ addPendingTask(index, readding=true)
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (finished(index)) {
+ finished(index) = false
+ copiesRunning(index) -= 1
+ tasksFinished -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ taskLost(tid, TaskState.KILLED, null)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the ClusterScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksFinished == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksFinished >= minFinishedForSpeculation) {
+ val time = clock.getTime()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ numTasks > 0 && tasksFinished < numTasks
+ }
+
+ private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+ val defaultWait = System.getProperty("spark.locality.wait", "3000")
+ level match {
+ case TaskLocality.PROCESS_LOCAL =>
+ System.getProperty("spark.locality.wait.process", defaultWait).toLong
+ case TaskLocality.NODE_LOCAL =>
+ System.getProperty("spark.locality.wait.node", defaultWait).toLong
+ case TaskLocality.RACK_LOCAL =>
+ System.getProperty("spark.locality.wait.rack", defaultWait).toLong
+ case TaskLocality.ANY =>
+ 0L
+ }
+ }
+
+ /**
+ * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+ * added to queues using addPendingTask.
+ */
+ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+ import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+ val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+ if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+ levels += PROCESS_LOCAL
+ }
+ if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+ levels += NODE_LOCAL
+ }
+ if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+ levels += RACK_LOCAL
+ }
+ levels += ANY
+ logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+ levels.toArray
+ }
+}
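
checkSpeculatableTasks boils down to a quantile gate plus a duration threshold, which is easy to sanity-check in isolation. The sketch below uses toy durations and the default knobs (0.75 quantile, 1.5 multiplier, 100 ms floor) and touches no Spark API:

    import scala.math.{max, min}

    val quantile   = 0.75
    val multiplier = 1.5
    val numTasks   = 8
    val finishedDurationsMs = Array(1000L, 1100L, 1200L, 1300L, 1250L, 1150L)   // 6 of 8 finished

    val minFinishedForSpeculation = (quantile * numTasks).floor.toInt           // 6
    if (finishedDurationsMs.length >= minFinishedForSpeculation) {
      val sorted = finishedDurationsMs.sorted
      val median = sorted(min((0.5 * numTasks).round.toInt, sorted.length - 1)) // 1250
      val threshold = max(multiplier * median, 100)                             // 1875.0
      println("speculate copies of tasks still running after " + threshold + " ms")
    }
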
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
new file mode 100644
index 0000000000..5077b2b48b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.executor.ExecutorExitCode
+
+/**
+ * Represents an explanation for an executor or whole slave failing or exiting.
+ */
+private[spark]
+class ExecutorLossReason(val message: String) {
+ override def toString: String = message
+}
+
+private[spark]
+case class ExecutorExited(val exitCode: Int)
+ extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) {
+}
+
+private[spark]
+case class SlaveLost(_message: String = "Slave lost")
+ extends ExecutorLossReason(_message) {
+}
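
A small, hypothetical pattern match over the two loss reasons (toy exit code, assuming Spark-internal scope):

    def describe(reason: ExecutorLossReason): String = reason match {
      case ExecutorExited(code) => "executor exited with code " + code + " (" + reason.message + ")"
      case SlaveLost(msg)       => "slave lost: " + msg
    }

    println(describe(ExecutorExited(1)))
    println(describe(SlaveLost()))   // slave lost: Slave lost
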
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
new file mode 100644
index 0000000000..35b32600da
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/**
+ * A Schedulable entity that represents a collection of Pools or TaskSetManagers.
+ */
+
+private[spark] class Pool(
+ val poolName: String,
+ val schedulingMode: SchedulingMode,
+ initMinShare: Int,
+ initWeight: Int)
+ extends Schedulable
+ with Logging {
+
+ var schedulableQueue = new ArrayBuffer[Schedulable]
+ var schedulableNameToSchedulable = new HashMap[String, Schedulable]
+
+ var weight = initWeight
+ var minShare = initMinShare
+ var runningTasks = 0
+
+ var priority = 0
+ var stageId = 0
+ var name = poolName
+ var parent:Schedulable = null
+
+ var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
+ schedulingMode match {
+ case SchedulingMode.FAIR =>
+ new FairSchedulingAlgorithm()
+ case SchedulingMode.FIFO =>
+ new FIFOSchedulingAlgorithm()
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {
+ schedulableQueue += schedulable
+ schedulableNameToSchedulable(schedulable.name) = schedulable
+ schedulable.parent= this
+ }
+
+ override def removeSchedulable(schedulable: Schedulable) {
+ schedulableQueue -= schedulable
+ schedulableNameToSchedulable -= schedulable.name
+ }
+
+ override def getSchedulableByName(schedulableName: String): Schedulable = {
+ if (schedulableNameToSchedulable.contains(schedulableName)) {
+ return schedulableNameToSchedulable(schedulableName)
+ }
+ for (schedulable <- schedulableQueue) {
+ var sched = schedulable.getSchedulableByName(schedulableName)
+ if (sched != null) {
+ return sched
+ }
+ }
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String) {
+ schedulableQueue.foreach(_.executorLost(executorId, host))
+ }
+
+ override def checkSpeculatableTasks(): Boolean = {
+ var shouldRevive = false
+ for (schedulable <- schedulableQueue) {
+ shouldRevive |= schedulable.checkSpeculatableTasks()
+ }
+ return shouldRevive
+ }
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ val sortedSchedulableQueue = schedulableQueue.sortWith(taskSetSchedulingAlgorithm.comparator)
+ for (schedulable <- sortedSchedulableQueue) {
+ sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue()
+ }
+ return sortedTaskSetQueue
+ }
+
+ override def increaseRunningTasks(taskNum: Int) {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int) {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ schedulableQueue.exists(_.hasPendingTasks())
+ }
+}
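
As a rough sketch (not part of the patch) of how these pieces compose, assuming made-up pool names: a root pool holds child pools, and running-task counts bubble up through the parent links.

    val rootPool = new Pool("root", SchedulingMode.FAIR, 0, 0)
    val prod  = new Pool("production", SchedulingMode.FIFO, 2, 3)  // minShare 2, weight 3
    val adhoc = new Pool("adhoc", SchedulingMode.FIFO, 0, 1)       // minShare 0, weight 1
    rootPool.addSchedulable(prod)
    rootPool.addSchedulable(adhoc)

    prod.increaseRunningTasks(4)                          // counts propagate up through parent
    assert(rootPool.runningTasks == 4)
    assert(rootPool.getSchedulableByName("adhoc") eq adhoc)
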
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
new file mode 100644
index 0000000000..f4726450ec
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+import scala.collection.mutable.ArrayBuffer
+/**
+ * An interface for schedulable entities.
+ * There are two types of Schedulable entities: Pools and TaskSetManagers.
+ */
+private[spark] trait Schedulable {
+ var parent: Schedulable
+ // child queues
+ def schedulableQueue: ArrayBuffer[Schedulable]
+ def schedulingMode: SchedulingMode
+ def weight: Int
+ def minShare: Int
+ def runningTasks: Int
+ def priority: Int
+ def stageId: Int
+ def name: String
+
+ def increaseRunningTasks(taskNum: Int): Unit
+ def decreaseRunningTasks(taskNum: Int): Unit
+ def addSchedulable(schedulable: Schedulable): Unit
+ def removeSchedulable(schedulable: Schedulable): Unit
+ def getSchedulableByName(name: String): Schedulable
+ def executorLost(executorId: String, host: String): Unit
+ def checkSpeculatableTasks(): Boolean
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager]
+ def hasPendingTasks(): Boolean
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala
new file mode 100644
index 0000000000..d04eeb6b98
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulableBuilder.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.io.{File, FileInputStream, FileOutputStream, FileNotFoundException}
+import java.util.Properties
+
+import scala.xml.XML
+
+import org.apache.spark.Logging
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+
+/**
+ * An interface for building the Schedulable tree.
+ * buildPools: builds the tree nodes (pools)
+ * addTaskSetManager: builds the leaf nodes (TaskSetManagers)
+ */
+private[spark] trait SchedulableBuilder {
+ def buildPools()
+ def addTaskSetManager(manager: Schedulable, properties: Properties)
+}
+
+private[spark] class FIFOSchedulableBuilder(val rootPool: Pool)
+ extends SchedulableBuilder with Logging {
+
+ override def buildPools() {
+ // nothing
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ rootPool.addSchedulable(manager)
+ }
+}
+
+private[spark] class FairSchedulableBuilder(val rootPool: Pool)
+ extends SchedulableBuilder with Logging {
+
+ val schedulerAllocFile = System.getProperty("spark.fairscheduler.allocation.file")
+ val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.cluster.fair.pool"
+ val DEFAULT_POOL_NAME = "default"
+ val MINIMUM_SHARES_PROPERTY = "minShare"
+ val SCHEDULING_MODE_PROPERTY = "schedulingMode"
+ val WEIGHT_PROPERTY = "weight"
+ val POOL_NAME_PROPERTY = "@name"
+ val POOLS_PROPERTY = "pool"
+ val DEFAULT_SCHEDULING_MODE = SchedulingMode.FIFO
+ val DEFAULT_MINIMUM_SHARE = 2
+ val DEFAULT_WEIGHT = 1
+
+ override def buildPools() {
+ if (schedulerAllocFile != null) {
+ val file = new File(schedulerAllocFile)
+ if (file.exists()) {
+ val xml = XML.loadFile(file)
+ for (poolNode <- (xml \\ POOLS_PROPERTY)) {
+
+ val poolName = (poolNode \ POOL_NAME_PROPERTY).text
+ var schedulingMode = DEFAULT_SCHEDULING_MODE
+ var minShare = DEFAULT_MINIMUM_SHARE
+ var weight = DEFAULT_WEIGHT
+
+ val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text
+ if (xmlSchedulingMode != "") {
+ try {
+ schedulingMode = SchedulingMode.withName(xmlSchedulingMode)
+ } catch {
+ case e: Exception => logInfo("Error xml schedulingMode, using default schedulingMode")
+ }
+ }
+
+ val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text
+ if (xmlMinShare != "") {
+ minShare = xmlMinShare.toInt
+ }
+
+ val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text
+ if (xmlWeight != "") {
+ weight = xmlWeight.toInt
+ }
+
+ val pool = new Pool(poolName, schedulingMode, minShare, weight)
+ rootPool.addSchedulable(pool)
+ logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ poolName, schedulingMode, minShare, weight))
+ }
+ } else {
+ throw new java.io.FileNotFoundException(
+ "Fair scheduler allocation file not found: " + schedulerAllocFile)
+ }
+ }
+
+ // finally create "default" pool
+ if (rootPool.getSchedulableByName(DEFAULT_POOL_NAME) == null) {
+ val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE,
+ DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(pool)
+ logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+
+ override def addTaskSetManager(manager: Schedulable, properties: Properties) {
+ var poolName = DEFAULT_POOL_NAME
+ var parentPool = rootPool.getSchedulableByName(poolName)
+ if (properties != null) {
+ poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME)
+ parentPool = rootPool.getSchedulableByName(poolName)
+ if (parentPool == null) {
+ // we will create a new pool that the user has configured in the app
+ // instead of one defined in the XML file
+ parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE,
+ DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)
+ rootPool.addSchedulable(parentPool)
+ logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format(
+ poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT))
+ }
+ }
+ parentPool.addSchedulable(manager)
+ logInfo("Added task set " + manager.name + " tasks to pool "+poolName)
+ }
+}
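
As a hedged sketch (not part of the patch) of the allocation-file format implied by the property names parsed above (pool / @name / schedulingMode / minShare / weight); the pool name and temp-file path are made up.

    import java.io.{File, PrintWriter}

    // A minimal allocation file matching what FairSchedulableBuilder.buildPools() expects.
    val allocXml =
      """<allocations>
        |  <pool name="production">
        |    <schedulingMode>FAIR</schedulingMode>
        |    <minShare>2</minShare>
        |    <weight>3</weight>
        |  </pool>
        |</allocations>""".stripMargin

    val allocFile = File.createTempFile("fairscheduler", ".xml")
    new PrintWriter(allocFile) { write(allocXml); close() }
    System.setProperty("spark.fairscheduler.allocation.file", allocFile.getAbsolutePath)

    val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
    new FairSchedulableBuilder(rootPool).buildPools()
    // rootPool now contains "production" plus the "default" pool created at the end of buildPools().
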
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
new file mode 100644
index 0000000000..bde2f73df4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.{SparkContext, Utils}
+
+/**
+ * A backend interface for cluster scheduling systems that allows plugging in different ones under
+ * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * machines become available and can launch tasks on them.
+ */
+private[spark] trait SchedulerBackend {
+ def start(): Unit
+ def stop(): Unit
+ def reviveOffers(): Unit
+ def defaultParallelism(): Int
+
+ // Memory used by each executor (in megabytes)
+ protected val executorMemory: Int = SparkContext.executorMemoryRequested
+
+ // TODO: Probably want to add a killTask too
+}
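
For orientation only (not part of the patch), a do-nothing implementation of the trait; the class name is hypothetical, and a real backend would launch executors in start() and hand resource offers back to the scheduler from reviveOffers().

    class NoopSchedulerBackend extends SchedulerBackend {
      override def start(): Unit = {}            // acquire executors here
      override def stop(): Unit = {}             // release them here
      override def reviveOffers(): Unit = {}     // offer free cores back to the scheduler
      override def defaultParallelism(): Int = 2
    }
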
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala
new file mode 100644
index 0000000000..cbeed4731a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingAlgorithm.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+/**
+ * An interface for the sort algorithms used by the scheduler.
+ * FIFO: FIFO ordering between TaskSetManagers
+ * FS: fair-sharing ordering between Pools, and FIFO or FS within Pools
+ */
+private[spark] trait SchedulingAlgorithm {
+ def comparator(s1: Schedulable, s2: Schedulable): Boolean
+}
+
+private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val priority1 = s1.priority
+ val priority2 = s2.priority
+ var res = math.signum(priority1 - priority2)
+ if (res == 0) {
+ val stageId1 = s1.stageId
+ val stageId2 = s2.stageId
+ res = math.signum(stageId1 - stageId2)
+ }
+ if (res < 0) {
+ return true
+ } else {
+ return false
+ }
+ }
+}
+
+private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm {
+ override def comparator(s1: Schedulable, s2: Schedulable): Boolean = {
+ val minShare1 = s1.minShare
+ val minShare2 = s2.minShare
+ val runningTasks1 = s1.runningTasks
+ val runningTasks2 = s2.runningTasks
+ val s1Needy = runningTasks1 < minShare1
+ val s2Needy = runningTasks2 < minShare2
+ val minShareRatio1 = runningTasks1.toDouble / math.max(minShare1, 1.0).toDouble
+ val minShareRatio2 = runningTasks2.toDouble / math.max(minShare2, 1.0).toDouble
+ val taskToWeightRatio1 = runningTasks1.toDouble / s1.weight.toDouble
+ val taskToWeightRatio2 = runningTasks2.toDouble / s2.weight.toDouble
+ var res:Boolean = true
+ var compare:Int = 0
+
+ if (s1Needy && !s2Needy) {
+ return true
+ } else if (!s1Needy && s2Needy) {
+ return false
+ } else if (s1Needy && s2Needy) {
+ compare = minShareRatio1.compareTo(minShareRatio2)
+ } else {
+ compare = taskToWeightRatio1.compareTo(taskToWeightRatio2)
+ }
+
+ if (compare < 0) {
+ return true
+ } else if (compare > 0) {
+ return false
+ } else {
+ return s1.name < s2.name
+ }
+ }
+}
+
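
As a sketch (not part of the patch) of the ordering the fair comparator produces, using made-up pool names: a pool still below its minShare sorts ahead of one that has met it, with ties falling back to the tasks-to-weight ratio and finally to the name.

    val fair = new FairSchedulingAlgorithm
    val starved = new Pool("starved", SchedulingMode.FIFO, 4, 1)   // minShare 4, weight 1
    val sated   = new Pool("sated", SchedulingMode.FIFO, 1, 1)     // minShare 1, weight 1
    starved.increaseRunningTasks(1)   // 1 < 4, still below its minimum share
    sated.increaseRunningTasks(3)     // 3 >= 1, minimum share satisfied
    assert(fair.comparator(starved, sated))    // the starved pool is scheduled first
    assert(!fair.comparator(sated, starved))
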
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala
new file mode 100644
index 0000000000..34811389a0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulingMode.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+/**
+ * "FAIR" and "FIFO" determines which policy is used
+ * to order tasks amongst a Schedulable's sub-queues
+ * "NONE" is used when the a Schedulable has no sub-queues.
+ */
+object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") {
+
+ type SchedulingMode = Value
+ val FAIR,FIFO,NONE = Value
+}
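
As a small usage sketch (not part of the patch); the enum is typically driven by the spark.cluster.schedulingmode system property read by the schedulers later in this patch.

    val mode = SchedulingMode.withName(System.getProperty("spark.cluster.schedulingmode", "FIFO"))
    assert(SchedulingMode.withName("FAIR") == SchedulingMode.FAIR)
    assert(mode == SchedulingMode.FIFO)  // holds when the property is unset
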
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
new file mode 100644
index 0000000000..ac6dc7d879
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.{Utils, Logging, SparkContext}
+import org.apache.spark.deploy.client.{Client, ClientListener}
+import org.apache.spark.deploy.{Command, ApplicationDescription}
+import scala.collection.mutable.HashMap
+
+private[spark] class SparkDeploySchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ master: String,
+ appName: String)
+ extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ with ClientListener
+ with Logging {
+
+ var client: Client = null
+ var stopping = false
+ var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
+
+ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
+
+ override def start() {
+ super.start()
+
+ // The endpoint for executors to talk to us
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+ val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
+ val command = Command(
+ "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
+ val sparkHome = sc.getSparkHome().getOrElse(null)
+ val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
+ sc.ui.appUIAddress)
+
+ client = new Client(sc.env.actorSystem, master, appDesc, this)
+ client.start()
+ }
+
+ override def stop() {
+ stopping = true
+ super.stop()
+ client.stop()
+ if (shutdownCallback != null) {
+ shutdownCallback(this)
+ }
+ }
+
+ override def connected(appId: String) {
+ logInfo("Connected to Spark cluster with app ID " + appId)
+ }
+
+ override def disconnected() {
+ if (!stopping) {
+ logError("Disconnected from Spark cluster!")
+ scheduler.error("Disconnected from Spark cluster")
+ }
+ }
+
+ override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
+ logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
+ executorId, hostPort, cores, Utils.megabytesToString(memory)))
+ }
+
+ override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
+ val reason: ExecutorLossReason = exitStatus match {
+ case Some(code) => ExecutorExited(code)
+ case None => SlaveLost(message)
+ }
+ logInfo("Executor %s removed: %s".format(executorId, message))
+ removeExecutor(executorId, reason.toString)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
new file mode 100644
index 0000000000..1cc5daf673
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.Utils
+import org.apache.spark.util.SerializableBuffer
+
+
+private[spark] sealed trait StandaloneClusterMessage extends Serializable
+
+private[spark] object StandaloneClusterMessages {
+
+ // Driver to executors
+ case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+
+ case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
+ extends StandaloneClusterMessage
+
+ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
+
+ // Executors to driver
+ case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
+ extends StandaloneClusterMessage {
+ Utils.checkHostPort(hostPort, "Expected host port")
+ }
+
+ case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
+ data: SerializableBuffer) extends StandaloneClusterMessage
+
+ object StatusUpdate {
+ /** Alternate factory method that takes a ByteBuffer directly for the data field */
+ def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer)
+ : StatusUpdate = {
+ StatusUpdate(executorId, taskId, state, new SerializableBuffer(data))
+ }
+ }
+
+ // Internal messages in driver
+ case object ReviveOffers extends StandaloneClusterMessage
+
+ case object StopDriver extends StandaloneClusterMessage
+
+ case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage
+
+}
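
As a sketch (not part of the patch) of the alternate StatusUpdate factory; the executor ID and payload bytes are made up. The overload wraps the raw ByteBuffer in a SerializableBuffer so the message can be serialized and shipped through Akka.

    import java.nio.ByteBuffer
    import org.apache.spark.TaskState
    import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._

    val payload = ByteBuffer.wrap(Array[Byte](1, 2, 3))
    val msg = StatusUpdate("exec-1", 42L, TaskState.FINISHED, payload)
    assert(msg.data.value == payload)   // the wrapped buffer is still accessible
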
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
new file mode 100644
index 0000000000..3677a827e0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
+import akka.dispatch.Await
+import akka.pattern.ask
+import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
+import akka.util.Duration
+import akka.util.duration._
+
+import org.apache.spark.{Utils, SparkException, Logging, TaskState}
+import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+
+/**
+ * A standalone scheduler backend, which waits for standalone executors to connect to it through
+ * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
+ * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
+ */
+private[spark]
+class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+ extends SchedulerBackend with Logging
+{
+ // Use an atomic variable to track total number of cores in the cluster for simplicity and speed
+ var totalCoreCount = new AtomicInteger(0)
+
+ class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
+ private val executorActor = new HashMap[String, ActorRef]
+ private val executorAddress = new HashMap[String, Address]
+ private val executorHost = new HashMap[String, String]
+ private val freeCores = new HashMap[String, Int]
+ private val actorToExecutorId = new HashMap[ActorRef, String]
+ private val addressToExecutorId = new HashMap[Address, String]
+
+ override def preStart() {
+ // Listen for remote client disconnection events, since they don't go through Akka's watch()
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+
+ // Periodically revive offers to allow delay scheduling to work
+ val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong
+ context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
+ }
+
+ def receive = {
+ case RegisterExecutor(executorId, hostPort, cores) =>
+ Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
+ if (executorActor.contains(executorId)) {
+ sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
+ } else {
+ logInfo("Registered executor: " + sender + " with ID " + executorId)
+ sender ! RegisteredExecutor(sparkProperties)
+ context.watch(sender)
+ executorActor(executorId) = sender
+ executorHost(executorId) = Utils.parseHostPort(hostPort)._1
+ freeCores(executorId) = cores
+ executorAddress(executorId) = sender.path.address
+ actorToExecutorId(sender) = executorId
+ addressToExecutorId(sender.path.address) = executorId
+ totalCoreCount.addAndGet(cores)
+ makeOffers()
+ }
+
+ case StatusUpdate(executorId, taskId, state, data) =>
+ scheduler.statusUpdate(taskId, state, data.value)
+ if (TaskState.isFinished(state)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ }
+
+ case ReviveOffers =>
+ makeOffers()
+
+ case StopDriver =>
+ sender ! true
+ context.stop(self)
+
+ case RemoveExecutor(executorId, reason) =>
+ removeExecutor(executorId, reason)
+ sender ! true
+
+ case Terminated(actor) =>
+ actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
+
+ case RemoteClientDisconnected(transport, address) =>
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
+
+ case RemoteClientShutdown(transport, address) =>
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
+ }
+
+ // Make fake resource offers on all executors
+ def makeOffers() {
+ launchTasks(scheduler.resourceOffers(
+ executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
+ }
+
+ // Make fake resource offers on just one executor
+ def makeOffers(executorId: String) {
+ launchTasks(scheduler.resourceOffers(
+ Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
+ }
+
+ // Launch tasks returned by a set of resource offers
+ def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
+ for (task <- tasks.flatten) {
+ freeCores(task.executorId) -= 1
+ executorActor(task.executorId) ! LaunchTask(task)
+ }
+ }
+
+ // Remove a disconnected slave from the cluster
+ def removeExecutor(executorId: String, reason: String) {
+ if (executorActor.contains(executorId)) {
+ logInfo("Executor " + executorId + " disconnected, so removing it")
+ val numCores = freeCores(executorId)
+ actorToExecutorId -= executorActor(executorId)
+ addressToExecutorId -= executorAddress(executorId)
+ executorActor -= executorId
+ executorHost -= executorId
+ freeCores -= executorId
+ totalCoreCount.addAndGet(-numCores)
+ scheduler.executorLost(executorId, SlaveLost(reason))
+ }
+ }
+ }
+
+ var driverActor: ActorRef = null
+ val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+ override def start() {
+ val properties = new ArrayBuffer[(String, String)]
+ val iterator = System.getProperties.entrySet.iterator
+ while (iterator.hasNext) {
+ val entry = iterator.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
+ properties += ((key, value))
+ }
+ }
+ driverActor = actorSystem.actorOf(
+ Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ }
+
+ private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
+ override def stop() {
+ try {
+ if (driverActor != null) {
+ val future = driverActor.ask(StopDriver)(timeout)
+ Await.result(future, timeout)
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error stopping standalone scheduler's driver actor", e)
+ }
+ }
+
+ override def reviveOffers() {
+ driverActor ! ReviveOffers
+ }
+
+ override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
+ .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
+
+ // Called by subclasses when notified of a lost worker
+ def removeExecutor(executorId: String, reason: String) {
+ try {
+ val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
+ Await.result(future, timeout)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error notifying standalone scheduler's driver actor", e)
+ }
+ }
+}
+
+private[spark] object StandaloneSchedulerBackend {
+ val ACTOR_NAME = "StandaloneScheduler"
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala
new file mode 100644
index 0000000000..309ac2f6c9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskDescription.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import org.apache.spark.util.SerializableBuffer
+
+private[spark] class TaskDescription(
+ val taskId: Long,
+ val executorId: String,
+ val name: String,
+ val index: Int, // Index within this task's TaskSet
+ _serializedTask: ByteBuffer)
+ extends Serializable {
+
+ // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer
+ private val buffer = new SerializableBuffer(_serializedTask)
+
+ def serializedTask: ByteBuffer = buffer.value
+
+ override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
+}
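
As a small sketch (not part of the patch), with made-up IDs: the wrapped buffer remains reachable through serializedTask, and toString gives a compact label for logging.

    import java.nio.ByteBuffer

    val task = new TaskDescription(7L, "exec-1", "task 0.0:3", 3, ByteBuffer.wrap(Array[Byte](42)))
    assert(task.serializedTask.get(0) == 42)
    println(task)   // TaskDescription(TID=7, index=3)
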
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala
new file mode 100644
index 0000000000..7ce14be7fb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskInfo.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.Utils
+
+/**
+ * Information about a running task attempt inside a TaskSet.
+ */
+private[spark]
+class TaskInfo(
+ val taskId: Long,
+ val index: Int,
+ val launchTime: Long,
+ val executorId: String,
+ val host: String,
+ val taskLocality: TaskLocality.TaskLocality) {
+
+ var finishTime: Long = 0
+ var failed = false
+
+ def markSuccessful(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ }
+
+ def markFailed(time: Long = System.currentTimeMillis) {
+ finishTime = time
+ failed = true
+ }
+
+ def finished: Boolean = finishTime != 0
+
+ def successful: Boolean = finished && !failed
+
+ def running: Boolean = !finished
+
+ def status: String = {
+ if (running)
+ "RUNNING"
+ else if (failed)
+ "FAILED"
+ else if (successful)
+ "SUCCESS"
+ else
+ "UNKNOWN"
+ }
+
+ def duration: Long = {
+ if (!finished) {
+ throw new UnsupportedOperationException("duration() called on unfinished tasks")
+ } else {
+ finishTime - launchTime
+ }
+ }
+
+ def timeRunning(currentTime: Long): Long = currentTime - launchTime
+}
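
As a sketch (not part of the patch) of a TaskInfo lifecycle, with made-up IDs and host names.

    val info = new TaskInfo(1L, 0, System.currentTimeMillis(), "exec-1", "host-1",
      TaskLocality.NODE_LOCAL)
    assert(info.running && info.status == "RUNNING")
    info.markSuccessful()
    assert(info.successful && info.status == "SUCCESS")
    println("ran for " + info.duration + " ms")   // duration is only defined once finished
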
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala
new file mode 100644
index 0000000000..5d4130e14a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskLocality.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+
+private[spark] object TaskLocality
+ extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
+{
+ // PROCESS_LOCAL is expected to be used ONLY within TaskSetManager for now.
+ val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
+
+ type TaskLocality = Value
+
+ def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+ condition <= constraint
+ }
+}
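
As a sketch (not part of the patch): isAllowed treats the constraint as an upper bound on how far from the data a task may run (PROCESS_LOCAL < NODE_LOCAL < RACK_LOCAL < ANY).

    import org.apache.spark.scheduler.cluster.TaskLocality._

    assert(isAllowed(constraint = NODE_LOCAL, condition = PROCESS_LOCAL))  // tighter locality is fine
    assert(!isAllowed(constraint = NODE_LOCAL, condition = RACK_LOCAL))    // looser locality is rejected
    assert(isAllowed(constraint = ANY, condition = RACK_LOCAL))
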
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
new file mode 100644
index 0000000000..648a3ef922
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.TaskSet
+
+/**
+ * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
+ * each task and is responsible for retries on failure and locality. The main interfaces to it
+ * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
+ * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
+ * (e.g. its event handlers). It should not be called from other threads.
+ */
+private[spark] trait TaskSetManager extends Schedulable {
+ def schedulableQueue = null
+
+ def schedulingMode = SchedulingMode.NONE
+
+ def taskSet: TaskSet
+
+ def resourceOffer(
+ execId: String,
+ host: String,
+ availableCpus: Int,
+ maxLocality: TaskLocality.TaskLocality)
+ : Option[TaskDescription]
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
+
+ def error(message: String)
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
new file mode 100644
index 0000000000..938f62883a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+/**
+ * Represents free resources available on an executor.
+ */
+private[spark]
+class WorkerOffer(val executorId: String, val host: String, val cores: Int)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
new file mode 100644
index 0000000000..f0ebe66d82
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -0,0 +1,272 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.io.File
+import java.lang.management.ManagementFactory
+import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster._
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+import akka.actor._
+
+/**
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * the scheduler also allows each task to fail up to maxFailures times, which is useful for
+ * testing fault recovery.
+ */
+
+private[spark]
+case class LocalReviveOffers()
+
+private[spark]
+case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+
+private[spark]
+class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+
+ def receive = {
+ case LocalReviveOffers =>
+ launchTask(localScheduler.resourceOffer(freeCores))
+ case LocalStatusUpdate(taskId, state, serializeData) =>
+ freeCores += 1
+ localScheduler.statusUpdate(taskId, state, serializeData)
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ def launchTask(tasks : Seq[TaskDescription]) {
+ for (task <- tasks) {
+ freeCores -= 1
+ localScheduler.threadPool.submit(new Runnable {
+ def run() {
+ localScheduler.runTask(task.taskId, task.serializedTask)
+ }
+ })
+ }
+ }
+}
+
+private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
+ extends TaskScheduler
+ with Logging {
+
+ var attemptId = new AtomicInteger(0)
+ var threadPool = Utils.newDaemonFixedThreadPool(threads)
+ val env = SparkEnv.get
+ var listener: TaskSchedulerListener = null
+
+ // Application dependencies (added through SparkContext) that we've fetched so far on this node.
+ // Each map holds the master's timestamp for the version of that file or JAR we got.
+ val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+ val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+
+ val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
+
+ var schedulableBuilder: SchedulableBuilder = null
+ var rootPool: Pool = null
+ val schedulingMode: SchedulingMode = SchedulingMode.withName(
+ System.getProperty("spark.cluster.schedulingmode", "FIFO"))
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+ val taskIdToTaskSetId = new HashMap[Long, String]
+ val taskSetTaskIds = new HashMap[String, HashSet[Long]]
+
+ var localActor: ActorRef = null
+
+ override def start() {
+ // temporarily set rootPool name to empty
+ rootPool = new Pool("", schedulingMode, 0, 0)
+ schedulableBuilder = {
+ schedulingMode match {
+ case SchedulingMode.FIFO =>
+ new FIFOSchedulableBuilder(rootPool)
+ case SchedulingMode.FAIR =>
+ new FairSchedulableBuilder(rootPool)
+ }
+ }
+ schedulableBuilder.buildPools()
+
+ localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
+ }
+
+ override def setListener(listener: TaskSchedulerListener) {
+ this.listener = listener
+ }
+
+ override def submitTasks(taskSet: TaskSet) {
+ synchronized {
+ val manager = new LocalTaskSetManager(this, taskSet)
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ activeTaskSets(taskSet.id) = manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ localActor ! LocalReviveOffers
+ }
+ }
+
+ def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
+ synchronized {
+ var freeCpuCores = freeCores
+ val tasks = new ArrayBuffer[TaskDescription](freeCores)
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logDebug("parentName:%s,name:%s,runningTasks:%s".format(
+ manager.parent.name, manager.name, manager.runningTasks))
+ }
+
+ var launchTask = false
+ for (manager <- sortedTaskSetQueue) {
+ do {
+ launchTask = false
+ manager.resourceOffer(null, null, freeCpuCores, null) match {
+ case Some(task) =>
+ tasks += task
+ taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += task.taskId
+ freeCpuCores -= 1
+ launchTask = true
+ case None => {}
+ }
+ } while(launchTask)
+ }
+ return tasks
+ }
+ }
+
+ def taskSetFinished(manager: TaskSetManager) {
+ synchronized {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds -= manager.taskSet.id
+ }
+ }
+
+ def runTask(taskId: Long, bytes: ByteBuffer) {
+ logInfo("Running " + taskId)
+ val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ // Set the Spark execution environment for the worker thread
+ SparkEnv.set(env)
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objectSer = SparkEnv.get.serializer.newInstance()
+ var attemptedTask: Option[Task[_]] = None
+ val start = System.currentTimeMillis()
+ var taskStart: Long = 0
+ def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
+ val startGCTime = getTotalGCTime
+
+ try {
+ Accumulators.clear()
+ Thread.currentThread().setContextClassLoader(classLoader)
+
+ // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
+ // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
+ val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
+ updateDependencies(taskFiles, taskJars) // Download any files added with addFile
+ val deserializedTask = ser.deserialize[Task[_]](
+ taskBytes, Thread.currentThread.getContextClassLoader)
+ attemptedTask = Some(deserializedTask)
+ val deserTime = System.currentTimeMillis() - start
+ taskStart = System.currentTimeMillis()
+
+ // Run it
+ val result: Any = deserializedTask.run(taskId)
+
+ // Serialize and deserialize the result to emulate what the Mesos
+ // executor does. This is useful to catch serialization errors early
+ // on in development (so when users move their local Spark programs
+ // to the cluster, they don't get surprised by serialization errors).
+ val serResult = objectSer.serialize(result)
+ deserializedTask.metrics.get.resultSize = serResult.limit()
+ val resultToReturn = objectSer.deserialize[Any](serResult)
+ val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
+ ser.serialize(Accumulators.values))
+ val serviceTime = System.currentTimeMillis() - taskStart
+ logInfo("Finished " + taskId)
+ deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
+ deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
+ deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
+ val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+ val serializedResult = ser.serialize(taskResult)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
+ } catch {
+ case t: Throwable => {
+ val serviceTime = System.currentTimeMillis() - taskStart
+ val metrics = attemptedTask.flatMap(t => t.metrics)
+ for (m <- metrics) {
+ m.executorRunTime = serviceTime.toInt
+ m.jvmGCTime = getTotalGCTime - startGCTime
+ }
+ val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
+ localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
+ }
+ }
+ }
+
+ /**
+ * Download any missing dependencies if we receive a new set of files and JARs from the
+ * SparkContext. Also adds any new JARs we fetched to the class loader.
+ */
+ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
+ synchronized {
+ // Fetch missing dependencies
+ for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentFiles(name) = timestamp
+ }
+
+ for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val localName = name.split("/").last
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!classLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ classLoader.addURL(url)
+ }
+ }
+ }
+ }
+
+ def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
+ synchronized {
+ val taskSetId = taskIdToTaskSetId(taskId)
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+ taskSetManager.statusUpdate(taskId, state, serializedData)
+ }
+ }
+
+ override def stop() {
+ threadPool.shutdownNow()
+ }
+
+ override def defaultParallelism() = threads
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
new file mode 100644
index 0000000000..e52cb998bd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.{Task, TaskResult, TaskSet}
+import org.apache.spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager}
+
+
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
+ extends TaskSetManager with Logging {
+
+ var parent: Schedulable = null
+ var weight: Int = 1
+ var minShare: Int = 0
+ var runningTasks: Int = 0
+ var priority: Int = taskSet.priority
+ var stageId: Int = taskSet.stageId
+ var name: String = "TaskSet_" + taskSet.stageId.toString
+
+ var failCount = new Array[Int](taskSet.tasks.size)
+ val taskInfos = new HashMap[Long, TaskInfo]
+ val numTasks = taskSet.tasks.size
+ var numFinished = 0
+ val env = SparkEnv.get
+ val ser = env.closureSerializer.newInstance()
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val MAX_TASK_FAILURES = sched.maxFailures
+
+ override def increaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ override def decreaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ override def addSchedulable(schedulable: Schedulable): Unit = {
+ // nothing
+ }
+
+ override def removeSchedulable(schedulable: Schedulable): Unit = {
+ // nothing
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def executorLost(executorId: String, host: String): Unit = {
+ // nothing
+ }
+
+ override def checkSpeculatableTasks() = true
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ override def hasPendingTasks() = true
+
+ def findTask(): Option[Int] = {
+ for (i <- 0 to numTasks-1) {
+ if (copiesRunning(i) == 0 && !finished(i)) {
+ return Some(i)
+ }
+ }
+ return None
+ }
+
+ override def resourceOffer(
+ execId: String,
+ host: String,
+ availableCpus: Int,
+ maxLocality: TaskLocality.TaskLocality)
+ : Option[TaskDescription] =
+ {
+ SparkEnv.set(sched.env)
+ logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
+ availableCpus.toInt, numFinished, numTasks))
+ if (availableCpus > 0 && numFinished < numTasks) {
+ findTask() match {
+ case Some(index) =>
+ val taskId = sched.attemptId.getAndIncrement()
+ val task = taskSet.tasks(index)
+ val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
+ TaskLocality.NODE_LOCAL)
+ taskInfos(taskId) = info
+ // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+ // we assume the task can be serialized without exceptions.
+ val bytes = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ copiesRunning(index) += 1
+ increaseRunningTasks(1)
+ taskStarted(task, info)
+ return Some(new TaskDescription(taskId, null, taskName, index, bytes))
+ case None => {}
+ }
+ }
+ return None
+ }
+
+ override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ SparkEnv.set(env)
+ state match {
+ case TaskState.FINISHED =>
+ taskEnded(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskFailed(tid, state, serializedData)
+ case _ => {}
+ }
+ }
+
+ def taskStarted(task: Task[_], info: TaskInfo) {
+ sched.listener.taskStarted(task, info)
+ }
+
+ def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markSuccessful()
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ numFinished += 1
+ decreaseRunningTasks(1)
+ finished(index) = true
+ if (numFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ }
+
+ def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markFailed()
+ decreaseRunningTasks(1)
+ val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
+ serializedData, getClass.getClassLoader)
+ sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ if (!finished(index)) {
+ copiesRunning(index) -= 1
+ numFailures(index) += 1
+ val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(
+ reason.className, reason.description, locs.mkString("\n")))
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
+ taskSet.id, index, MAX_TASK_FAILURES, reason.description)
+ decreaseRunningTasks(runningTasks)
+ sched.listener.taskSetFailed(taskSet, errorMessage)
+ // need to remove the failed TaskSet from the scheduling queue
+ sched.taskSetFinished(this)
+ }
+ }
+ }
+
+ override def error(message: String) {
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
new file mode 100644
index 0000000000..f6a2feab28
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.mesos
+
+import com.google.protobuf.ByteString
+
+import org.apache.mesos.{Scheduler => MScheduler}
+import org.apache.mesos._
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
+
+import org.apache.spark.{SparkException, Utils, Logging, SparkContext}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import java.io.File
+import org.apache.spark.scheduler.cluster._
+import java.util.{ArrayList => JArrayList, List => JList}
+import java.util.Collections
+import org.apache.spark.TaskState
+
+/**
+ * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
+ * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
+ * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
+ * StandaloneBackend mechanism. This class is useful for lower and more predictable latency.
+ *
+ * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
+ * remove this.
+ */
+private[spark] class CoarseMesosSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ master: String,
+ appName: String)
+ extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ with MScheduler
+ with Logging {
+
+ val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures
+
+ // Lock used to wait for scheduler to be registered
+ var isRegistered = false
+ val registeredLock = new Object()
+
+ // Driver for talking to Mesos
+ var driver: SchedulerDriver = null
+
+ // Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
+ val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
+
+ // Cores we have acquired with each Mesos task ID
+ val coresByTaskId = new HashMap[Int, Int]
+ var totalCoresAcquired = 0
+
+ val slaveIdsWithExecutors = new HashSet[String]
+
+ val taskIdToSlaveId = new HashMap[Int, String]
+ val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed
+
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
+
+ val extraCoresPerSlave = System.getProperty("spark.mesos.extra.cores", "0").toInt
+
+ var nextMesosTaskId = 0
+
+ def newMesosTaskId(): Int = {
+ val id = nextMesosTaskId
+ nextMesosTaskId += 1
+ id
+ }
+
+ override def start() {
+ super.start()
+
+ synchronized {
+ new Thread("CoarseMesosSchedulerBackend driver") {
+ setDaemon(true)
+ override def run() {
+ val scheduler = CoarseMesosSchedulerBackend.this
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
+ driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
+ try {
+ val ret = driver.run()
+ logInfo("driver.run() returned with code " + ret)
+ } catch {
+ case e: Exception => logError("driver.run() failed", e)
+ }
+ }
+ }.start()
+
+ waitForRegister()
+ }
+ }
+
+ def createCommand(offer: Offer, numCores: Int): CommandInfo = {
+ val environment = Environment.newBuilder()
+ sc.executorEnvs.foreach { case (key, value) =>
+ environment.addVariables(Environment.Variable.newBuilder()
+ .setName(key)
+ .setValue(value)
+ .build())
+ }
+ val command = CommandInfo.newBuilder()
+ .setEnvironment(environment)
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"),
+ System.getProperty("spark.driver.port"),
+ StandaloneSchedulerBackend.ACTOR_NAME)
+ val uri = System.getProperty("spark.executor.uri")
+ if (uri == null) {
+ val runScript = new File(sparkHome, "spark-class").getCanonicalPath
+ command.setValue(
+ "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ } else {
+ // Grab everything to the first '.'. We'll use that and '*' to
+ // glob the directory "correctly".
+ val basename = uri.split('/').last.split('.').head
+ command.setValue(
+ "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
+ }
+ return command.build()
+ }
+
+ override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
+ logInfo("Registered as framework ID " + frameworkId.getValue)
+ registeredLock.synchronized {
+ isRegistered = true
+ registeredLock.notifyAll()
+ }
+ }
+
+ def waitForRegister() {
+ registeredLock.synchronized {
+ while (!isRegistered) {
+ registeredLock.wait()
+ }
+ }
+ }
+
+ override def disconnected(d: SchedulerDriver) {}
+
+ override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
+
+ /**
+ * Method called by Mesos to offer resources on slaves. We respond by launching an executor,
+ * unless we've already launched more than we wanted to.
+ */
+ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
+ synchronized {
+ val filters = Filters.newBuilder().setRefuseSeconds(-1).build()
+
+ for (offer <- offers) {
+ val slaveId = offer.getSlaveId.getValue
+ val mem = getResource(offer.getResourcesList, "mem")
+ val cpus = getResource(offer.getResourcesList, "cpus").toInt
+ if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 &&
+ failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES &&
+ !slaveIdsWithExecutors.contains(slaveId)) {
+ // Launch an executor on the slave
+ val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired)
+ totalCoresAcquired += cpusToUse
+ val taskId = newMesosTaskId()
+ taskIdToSlaveId(taskId) = slaveId
+ slaveIdsWithExecutors += slaveId
+ coresByTaskId(taskId) = cpusToUse
+ val task = MesosTaskInfo.newBuilder()
+ .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build())
+ .setSlaveId(offer.getSlaveId)
+ .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave))
+ .setName("Task " + taskId)
+ .addResources(createResource("cpus", cpusToUse))
+ .addResources(createResource("mem", executorMemory))
+ .build()
+ d.launchTasks(offer.getId, Collections.singletonList(task), filters)
+ } else {
+ // Filter it out
+ d.launchTasks(offer.getId, Collections.emptyList[MesosTaskInfo](), filters)
+ }
+ }
+ }
+ }
+
+ /** Helper function to pull out a resource from a Mesos Resources protobuf */
+ private def getResource(res: JList[Resource], name: String): Double = {
+ for (r <- res if r.getName == name) {
+ return r.getScalar.getValue
+ }
+ // If we reached here, no resource with the required name was present
+ throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ }
+
+ /** Build a Mesos resource protobuf object */
+ private def createResource(resourceName: String, quantity: Double): Protos.Resource = {
+ Resource.newBuilder()
+ .setName(resourceName)
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(quantity).build())
+ .build()
+ }
+
+ /** Check whether a Mesos task state represents a finished task */
+ private def isFinished(state: MesosTaskState) = {
+ state == MesosTaskState.TASK_FINISHED ||
+ state == MesosTaskState.TASK_FAILED ||
+ state == MesosTaskState.TASK_KILLED ||
+ state == MesosTaskState.TASK_LOST
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ val taskId = status.getTaskId.getValue.toInt
+ val state = status.getState
+ logInfo("Mesos task " + taskId + " is now " + state)
+ synchronized {
+ if (isFinished(state)) {
+ val slaveId = taskIdToSlaveId(taskId)
+ slaveIdsWithExecutors -= slaveId
+ taskIdToSlaveId -= taskId
+ // Remove the cores we have remembered for this task, if it's in the hashmap
+ for (cores <- coresByTaskId.get(taskId)) {
+ totalCoresAcquired -= cores
+ coresByTaskId -= taskId
+ }
+ // If it was a failure, mark the slave as failed for blacklisting purposes
+ if (state == MesosTaskState.TASK_FAILED || state == MesosTaskState.TASK_LOST) {
+ failuresBySlaveId(slaveId) = failuresBySlaveId.getOrElse(slaveId, 0) + 1
+ if (failuresBySlaveId(slaveId) >= MAX_SLAVE_FAILURES) {
+ logInfo("Blacklisting Mesos slave " + slaveId + " due to too many failures; " +
+ "is Spark installed on it?")
+ }
+ }
+ driver.reviveOffers() // In case we'd rejected everything before but have now lost a node
+ }
+ }
+ }
+
+ override def error(d: SchedulerDriver, message: String) {
+ logError("Mesos error: " + message)
+ scheduler.error(message)
+ }
+
+ override def stop() {
+ super.stop()
+ if (driver != null) {
+ driver.stop()
+ }
+ }
+
+ override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+
+ override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
+ logInfo("Mesos slave lost: " + slaveId.getValue)
+ synchronized {
+ if (slaveIdsWithExecutors.contains(slaveId.getValue)) {
+ // Note that the slave ID corresponds to the executor ID on that slave
+ slaveIdsWithExecutors -= slaveId.getValue
+ removeExecutor(slaveId.getValue, "Mesos slave lost")
+ }
+ }
+ }
+
+ override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) {
+ logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue))
+ slaveLost(d, s)
+ }
+}
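
The heart of the coarse-grained offer handling above is the cap on acquired cores: each acceptable offer contributes min(offered CPUs, remaining budget) until the spark.cores.max limit is reached. A condensed, standalone sketch of that accounting (the offer sizes below are made up):

object CoarseOfferAccountingSketch {
  def main(args: Array[String]) {
    val maxCores = 8                  // would come from spark.cores.max
    var totalCoresAcquired = 0
    val offeredCpus = Seq(4, 4, 4)    // three hypothetical resource offers

    for (cpus <- offeredCpus) {
      if (totalCoresAcquired < maxCores && cpus >= 1) {
        // Take only as many cores as the remaining budget allows
        val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired)
        totalCoresAcquired += cpusToUse
        println("Launching executor with " + cpusToUse + " cores")
      } else {
        println("Declining offer with " + cpus + " cores")
      }
    }
  }
}
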
diff --git a/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala
new file mode 100644
index 0000000000..e002af1742
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackend.scala
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.mesos
+
+import com.google.protobuf.ByteString
+
+import org.apache.mesos.{Scheduler => MScheduler}
+import org.apache.mesos._
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
+
+import org.apache.spark.{SparkException, Utils, Logging, SparkContext}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.JavaConversions._
+import java.io.File
+import org.apache.spark.scheduler.cluster._
+import java.util.{ArrayList => JArrayList, List => JList}
+import java.util.Collections
+import org.apache.spark.TaskState
+
+/**
+ * A SchedulerBackend for running fine-grained tasks on Mesos. Each Spark task is mapped to a
+ * separate Mesos task, allowing multiple applications to share cluster nodes both in space (tasks
+ * from multiple apps can run on different cores) and in time (a core can switch ownership).
+ */
+private[spark] class MesosSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ master: String,
+ appName: String)
+ extends SchedulerBackend
+ with MScheduler
+ with Logging {
+
+ // Lock used to wait for scheduler to be registered
+ var isRegistered = false
+ val registeredLock = new Object()
+
+ // Driver for talking to Mesos
+ var driver: SchedulerDriver = null
+
+ // Which slave IDs we have executors on
+ val slaveIdsWithExecutors = new HashSet[String]
+ val taskIdToSlaveId = new HashMap[Long, String]
+
+ // An ExecutorInfo for our tasks
+ var execArgs: Array[Byte] = null
+
+ var classLoader: ClassLoader = null
+
+ override def start() {
+ synchronized {
+ classLoader = Thread.currentThread.getContextClassLoader
+
+ new Thread("MesosSchedulerBackend driver") {
+ setDaemon(true)
+ override def run() {
+ val scheduler = MesosSchedulerBackend.this
+ val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build()
+ driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
+ try {
+ val ret = driver.run()
+ logInfo("driver.run() returned with code " + ret)
+ } catch {
+ case e: Exception => logError("driver.run() failed", e)
+ }
+ }
+ }.start()
+
+ waitForRegister()
+ }
+ }
+
+ def createExecutorInfo(execId: String): ExecutorInfo = {
+ val sparkHome = sc.getSparkHome().getOrElse(throw new SparkException(
+ "Spark home is not set; set it through the spark.home system " +
+ "property, the SPARK_HOME environment variable or the SparkContext constructor"))
+ val environment = Environment.newBuilder()
+ sc.executorEnvs.foreach { case (key, value) =>
+ environment.addVariables(Environment.Variable.newBuilder()
+ .setName(key)
+ .setValue(value)
+ .build())
+ }
+ val command = CommandInfo.newBuilder()
+ .setEnvironment(environment)
+ val uri = System.getProperty("spark.executor.uri")
+ if (uri == null) {
+ command.setValue(new File(sparkHome, "spark-executor").getCanonicalPath)
+ } else {
+ // Grab everything to the first '.'. We'll use that and '*' to
+ // glob the directory "correctly".
+ val basename = uri.split('/').last.split('.').head
+ command.setValue("cd %s*; ./spark-executor".format(basename))
+ command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
+ }
+ val memory = Resource.newBuilder()
+ .setName("mem")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build())
+ .build()
+ ExecutorInfo.newBuilder()
+ .setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
+ .setCommand(command)
+ .setData(ByteString.copyFrom(createExecArg()))
+ .addResources(memory)
+ .build()
+ }
+
+ /**
+ * Create and serialize the executor argument to pass to Mesos. Our executor arg is an array
+ * containing all the spark.* system properties in the form of (String, String) pairs.
+ */
+ private def createExecArg(): Array[Byte] = {
+ if (execArgs == null) {
+ val props = new HashMap[String, String]
+ val iterator = System.getProperties.entrySet.iterator
+ while (iterator.hasNext) {
+ val entry = iterator.next
+ val (key, value) = (entry.getKey.toString, entry.getValue.toString)
+ if (key.startsWith("spark.")) {
+ props(key) = value
+ }
+ }
+ // Serialize the map as an array of (String, String) pairs
+ execArgs = Utils.serialize(props.toArray)
+ }
+ return execArgs
+ }
+
+ private def setClassLoader(): ClassLoader = {
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+ return oldClassLoader
+ }
+
+ private def restoreClassLoader(oldClassLoader: ClassLoader) {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
+ }
+
+ override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
+
+ override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
+ val oldClassLoader = setClassLoader()
+ try {
+ logInfo("Registered as framework ID " + frameworkId.getValue)
+ registeredLock.synchronized {
+ isRegistered = true
+ registeredLock.notifyAll()
+ }
+ } finally {
+ restoreClassLoader(oldClassLoader)
+ }
+ }
+
+ def waitForRegister() {
+ registeredLock.synchronized {
+ while (!isRegistered) {
+ registeredLock.wait()
+ }
+ }
+ }
+
+ override def disconnected(d: SchedulerDriver) {}
+
+ override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
+
+ /**
+ * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+ * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+ * tasks are balanced across the cluster.
+ */
+ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
+ val oldClassLoader = setClassLoader()
+ try {
+ synchronized {
+ // Build a big list of the offerable workers, and remember their indices so that we can
+ // figure out which Offer to reply to for each worker
+ val offerableIndices = new ArrayBuffer[Int]
+ val offerableWorkers = new ArrayBuffer[WorkerOffer]
+
+ def enoughMemory(o: Offer) = {
+ val mem = getResource(o.getResourcesList, "mem")
+ val slaveId = o.getSlaveId.getValue
+ mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId)
+ }
+
+ for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) {
+ offerableIndices += index
+ offerableWorkers += new WorkerOffer(
+ offer.getSlaveId.getValue,
+ offer.getHostname,
+ getResource(offer.getResourcesList, "cpus").toInt)
+ }
+
+ // Call into the ClusterScheduler
+ val taskLists = scheduler.resourceOffers(offerableWorkers)
+
+ // Build a list of Mesos tasks for each slave
+ val mesosTasks = offers.map(o => Collections.emptyList[MesosTaskInfo]())
+ for ((taskList, index) <- taskLists.zipWithIndex) {
+ if (!taskList.isEmpty) {
+ val offerNum = offerableIndices(index)
+ val slaveId = offers(offerNum).getSlaveId.getValue
+ slaveIdsWithExecutors += slaveId
+ mesosTasks(offerNum) = new JArrayList[MesosTaskInfo](taskList.size)
+ for (taskDesc <- taskList) {
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
+ }
+ }
+ }
+
+ // Reply to the offers
+ val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
+ for (i <- 0 until offers.size) {
+ d.launchTasks(offers(i).getId, mesosTasks(i), filters)
+ }
+ }
+ } finally {
+ restoreClassLoader(oldClassLoader)
+ }
+ }
+
+ /** Helper function to pull out a resource from a Mesos Resources protobuf */
+ def getResource(res: JList[Resource], name: String): Double = {
+ for (r <- res if r.getName == name) {
+ return r.getScalar.getValue
+ }
+ // If we reached here, no resource with the required name was present
+ throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ }
+
+ /** Turn a Spark TaskDescription into a Mesos task */
+ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = {
+ val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build()
+ val cpuResource = Resource.newBuilder()
+ .setName("cpus")
+ .setType(Value.Type.SCALAR)
+ .setScalar(Value.Scalar.newBuilder().setValue(1).build())
+ .build()
+ return MesosTaskInfo.newBuilder()
+ .setTaskId(taskId)
+ .setSlaveId(SlaveID.newBuilder().setValue(slaveId).build())
+ .setExecutor(createExecutorInfo(slaveId))
+ .setName(task.name)
+ .addResources(cpuResource)
+ .setData(ByteString.copyFrom(task.serializedTask))
+ .build()
+ }
+
+ /** Check whether a Mesos task state represents a finished task */
+ def isFinished(state: MesosTaskState) = {
+ state == MesosTaskState.TASK_FINISHED ||
+ state == MesosTaskState.TASK_FAILED ||
+ state == MesosTaskState.TASK_KILLED ||
+ state == MesosTaskState.TASK_LOST
+ }
+
+ override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
+ val oldClassLoader = setClassLoader()
+ try {
+ val tid = status.getTaskId.getValue.toLong
+ val state = TaskState.fromMesos(status.getState)
+ synchronized {
+ if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
+ // We lost the executor on this slave, so remember that it's gone
+ slaveIdsWithExecutors -= taskIdToSlaveId(tid)
+ }
+ if (isFinished(status.getState)) {
+ taskIdToSlaveId.remove(tid)
+ }
+ }
+ scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
+ } finally {
+ restoreClassLoader(oldClassLoader)
+ }
+ }
+
+ override def error(d: SchedulerDriver, message: String) {
+ val oldClassLoader = setClassLoader()
+ try {
+ logError("Mesos error: " + message)
+ scheduler.error(message)
+ } finally {
+ restoreClassLoader(oldClassLoader)
+ }
+ }
+
+ override def stop() {
+ if (driver != null) {
+ driver.stop()
+ }
+ }
+
+ override def reviveOffers() {
+ driver.reviveOffers()
+ }
+
+ override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+
+ private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
+ val oldClassLoader = setClassLoader()
+ try {
+ logInfo("Mesos slave lost: " + slaveId.getValue)
+ synchronized {
+ slaveIdsWithExecutors -= slaveId.getValue
+ }
+ scheduler.executorLost(slaveId.getValue, reason)
+ } finally {
+ restoreClassLoader(oldClassLoader)
+ }
+ }
+
+ override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) {
+ recordSlaveLost(d, slaveId, SlaveLost())
+ }
+
+ override def executorLost(d: SchedulerDriver, executorId: ExecutorID,
+ slaveId: SlaveID, status: Int) {
+ logInfo("Executor lost: %s, marking slave %s as lost".format(executorId.getValue,
+ slaveId.getValue))
+ recordSlaveLost(d, slaveId, ExecutorExited(status))
+ }
+
+ // TODO: query Mesos for number of cores
+ override def defaultParallelism() = System.getProperty("spark.default.parallelism", "8").toInt
+}
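
createExecArg above ships the driver's spark.* configuration to each executor as a serialized array of (String, String) pairs. A rough standalone equivalent, assuming plain Java serialization in place of Spark's Utils.serialize and an invented property name:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.JavaConversions._

object ExecArgSketch {
  def main(args: Array[String]) {
    System.setProperty("spark.example.setting", "42")   // hypothetical property
    // Collect every spark.* system property as a (key, value) pair
    val props = System.getProperties.entrySet.toSeq
      .map(e => (e.getKey.toString, e.getValue.toString))
      .filter(_._1.startsWith("spark."))
      .toArray

    // Serialize the pairs; the real backend hands bytes like these to each executor
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    out.writeObject(props)
    out.close()
    println("Serialized " + props.size + " spark.* properties into " + bos.size + " bytes")
  }
}
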
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
new file mode 100644
index 0000000000..160cca4d6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{EOFException, InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.util.{NextIterator, ByteBufferInputStream}
+
+
+/**
+ * A serializer. Because some serialization libraries are not thread safe, this class is used to
+ * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual serialization and are
+ * guaranteed to only be called from one thread at a time.
+ */
+trait Serializer {
+ def newInstance(): SerializerInstance
+}
+
+
+/**
+ * An instance of a serializer, for use by one thread at a time.
+ */
+trait SerializerInstance {
+ def serialize[T](t: T): ByteBuffer
+
+ def deserialize[T](bytes: ByteBuffer): T
+
+ def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+
+ def serializeStream(s: OutputStream): SerializationStream
+
+ def deserializeStream(s: InputStream): DeserializationStream
+
+ def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ // Default implementation uses serializeStream
+ val stream = new FastByteArrayOutputStream()
+ serializeStream(stream).writeAll(iterator)
+ val buffer = ByteBuffer.allocate(stream.position.toInt)
+ buffer.put(stream.array, 0, stream.position.toInt)
+ buffer.flip()
+ buffer
+ }
+
+ def deserializeMany(buffer: ByteBuffer): Iterator[Any] = {
+ // Default implementation uses deserializeStream
+ buffer.rewind()
+ deserializeStream(new ByteBufferInputStream(buffer)).asIterator
+ }
+}
+
+
+/**
+ * A stream for writing serialized objects.
+ */
+trait SerializationStream {
+ def writeObject[T](t: T): SerializationStream
+ def flush(): Unit
+ def close(): Unit
+
+ def writeAll[T](iter: Iterator[T]): SerializationStream = {
+ while (iter.hasNext) {
+ writeObject(iter.next())
+ }
+ this
+ }
+}
+
+
+/**
+ * A stream for reading serialized objects.
+ */
+trait DeserializationStream {
+ def readObject[T](): T
+ def close(): Unit
+
+ /**
+ * Read the elements of this stream through an iterator. This can only be called once, as
+ * reading each element will consume data from the input source.
+ */
+ def asIterator: Iterator[Any] = new NextIterator[Any] {
+ override protected def getNext() = {
+ try {
+ readObject[Any]()
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ }
+
+ override protected def close() {
+ DeserializationStream.this.close()
+ }
+ }
+}
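
As a purely illustrative instance of the traits above, a serializer backed by plain Java serialization might be sketched as follows; Spark ships its own serializers, and this sketch ignores the class-loader plumbing a real implementation needs:

package org.apache.spark.serializer

import java.io._
import java.nio.ByteBuffer

// Hypothetical sketch: wraps ObjectOutputStream/ObjectInputStream behind the
// Serializer traits defined in this file.
class SketchJavaSerializer extends Serializer {
  def newInstance(): SerializerInstance = new SerializerInstance {

    def serializeStream(s: OutputStream): SerializationStream = new SerializationStream {
      val objOut = new ObjectOutputStream(s)
      def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
      def flush() { objOut.flush() }
      def close() { objOut.close() }
    }

    def deserializeStream(s: InputStream): DeserializationStream = new DeserializationStream {
      val objIn = new ObjectInputStream(s)
      def readObject[T](): T = objIn.readObject().asInstanceOf[T]
      def close() { objIn.close() }
    }

    def serialize[T](t: T): ByteBuffer = {
      val bos = new ByteArrayOutputStream()
      val out = serializeStream(bos)
      out.writeObject(t)
      out.close()
      ByteBuffer.wrap(bos.toByteArray)
    }

    def deserialize[T](bytes: ByteBuffer): T = {
      val arr = new Array[Byte](bytes.remaining())
      bytes.get(arr)
      deserializeStream(new ByteArrayInputStream(arr)).readObject[T]()
    }

    // The class-loader variant is ignored here; a real implementation would
    // resolve classes through the given loader.
    def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = deserialize(bytes)
  }
}
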
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..2955986fec
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.util.concurrent.ConcurrentHashMap
+
+
+/**
+ * A service that returns a serializer object given the serializer's class name. If a previous
+ * instance of the serializer object has been created, the get method returns that instead of
+ * creating a new one.
+ */
+private[spark] class SerializerManager {
+
+ private val serializers = new ConcurrentHashMap[String, Serializer]
+ private var _default: Serializer = _
+
+ def default = _default
+
+ def setDefault(clsName: String): Serializer = {
+ _default = get(clsName)
+ _default
+ }
+
+ def get(clsName: String): Serializer = {
+ if (clsName == null) {
+ default
+ } else {
+ var serializer = serializers.get(clsName)
+ if (serializer != null) {
+ // If the serializer has been created previously, reuse that.
+ serializer
+ } else this.synchronized {
+ // Otherwise, create a new one. But make sure no other thread has attempted
+ // to create another new one at the same time.
+ serializer = serializers.get(clsName)
+ if (serializer == null) {
+ val clsLoader = Thread.currentThread.getContextClassLoader
+ serializer =
+ Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer]
+ serializers.put(clsName, serializer)
+ }
+ serializer
+ }
+ }
+ }
+}
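
Hypothetical usage of the manager above, reusing the SketchJavaSerializer from the earlier sketch purely as a stand-in for any Serializer implementation reachable from the context class loader (the manager is private[spark], so the sketch lives in the same package):

package org.apache.spark.serializer

object SerializerManagerUsageSketch {
  def main(args: Array[String]) {
    val manager = new SerializerManager
    // Resolve and cache the default serializer by class name (assumed to be on
    // the classpath and to implement the Serializer trait).
    val default = manager.setDefault("org.apache.spark.serializer.SketchJavaSerializer")
    // A null class name falls back to the cached default instance.
    assert(manager.get(null) eq default)
  }
}
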
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
new file mode 100644
index 0000000000..290dbce4f5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+private[spark]
+case class BlockException(blockId: String, message: String) extends Exception(message)
+
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
new file mode 100644
index 0000000000..2e0b0e6eda
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetchTracker.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+private[spark] trait BlockFetchTracker {
+ def totalBlocks: Int
+ def numLocalBlocks: Int
+ def numRemoteBlocks: Int
+ def remoteFetchTime: Long
+ def fetchWaitTime: Long
+ def remoteBytesRead: Long
+}
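
The trait only exposes read-only fetch metrics. A trivial stub implementation, hypothetical and of the kind one might use in tests, could look like this; the real numbers are produced by BlockFetcherIterator below:

package org.apache.spark.storage

// Hypothetical no-op tracker; every metric is reported as zero.
private[spark] class NoopBlockFetchTracker extends BlockFetchTracker {
  def totalBlocks: Int = 0
  def numLocalBlocks: Int = 0
  def numRemoteBlocks: Int = 0
  def remoteFetchTime: Long = 0L
  def fetchWaitTime: Long = 0L
  def remoteBytesRead: Long = 0L
}
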
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
new file mode 100644
index 0000000000..c91f0fc1ad
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.Queue
+
+import io.netty.buffer.ByteBuf
+
+import org.apache.spark.Logging
+import org.apache.spark.Utils
+import org.apache.spark.SparkException
+import org.apache.spark.network.BufferMessage
+import org.apache.spark.network.ConnectionManagerId
+import org.apache.spark.network.netty.ShuffleCopier
+import org.apache.spark.serializer.Serializer
+
+
+/**
+ * A block fetcher iterator interface. There are two implementations:
+ *
+ * BasicBlockFetcherIterator: uses a custom-built NIO communication layer.
+ * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer.
+ *
+ * Eventually we would like the two to converge and use a single NIO-based communication layer,
+ * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores),
+ * NIO can perform poorly, hence the need for the Netty OIO implementation.
+ */
+
+private[storage]
+trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+ with Logging with BlockFetchTracker {
+ def initialize()
+}
+
+
+private[storage]
+object BlockFetcherIterator {
+
+ // A request to fetch one or more blocks, complete with their sizes
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ val size = blocks.map(_._2).sum
+ }
+
+ // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+ // the block (since we want all deserialization to happen in the calling thread); can also
+ // represent a fetch failure if size == -1.
+ class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ def failed: Boolean = size == -1
+ }
+
+ class BasicBlockFetcherIterator(
+ private val blockManager: BlockManager,
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BlockFetcherIterator {
+
+ import blockManager._
+
+ private var _remoteBytesRead = 0L
+ private var _remoteFetchTime = 0L
+ private var _fetchWaitTime = 0L
+
+ if (blocksByAddress == null) {
+ throw new IllegalArgumentException("BlocksByAddress is null")
+ }
+
+ // Total number of blocks fetched (local + remote). Also the number of FetchResults expected
+ protected var _numBlocksToFetch = 0
+
+ protected var startTime = System.currentTimeMillis
+
+ // This represents the number of local blocks, also counting zero-sized blocks
+ private var numLocal = 0
+ // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
+ protected val localBlocksToFetch = new ArrayBuffer[String]()
+
+ // This represents the number of remote blocks, also counting zero-sized blocks
+ private var numRemote = 0
+ // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
+ protected val remoteBlocksToFetch = new HashSet[String]()
+
+ // A queue to hold our results.
+ protected val results = new LinkedBlockingQueue[FetchResult]
+
+ // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ // the number of bytes in flight is limited to maxBytesInFlight
+ private val fetchRequests = new Queue[FetchRequest]
+
+ // Current bytes in flight from our requests
+ private var bytesInFlight = 0L
+
+ protected def sendRequest(req: FetchRequest) {
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.port)
+ val blockMessageArray = new BlockMessageArray(req.blocks.map {
+ case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+ })
+ bytesInFlight += req.size
+ val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
+ val fetchStart = System.currentTimeMillis()
+ val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+ future.onSuccess {
+ case Some(message) => {
+ val fetchDone = System.currentTimeMillis()
+ _remoteFetchTime += fetchDone - fetchStart
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+ for (blockMessage <- blockMessageArray) {
+ if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+ throw new SparkException(
+ "Unexpected message " + blockMessage.getType + " received from " + cmId)
+ }
+ val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
+ results.put(new FetchResult(blockId, sizeMap(blockId),
+ () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+ _remoteBytesRead += networkSize
+ logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
+ }
+ case None => {
+ logError("Could not get block(s) from " + cmId)
+ for ((blockId, size) <- req.blocks) {
+ results.put(new FetchResult(blockId, -1, null))
+ }
+ }
+ }
+ }
+
+ protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
+ // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+ // at most maxBytesInFlight in order to limit the amount of data in flight.
+ val remoteRequests = new ArrayBuffer[FetchRequest]
+ for ((address, blockInfos) <- blocksByAddress) {
+ if (address == blockManagerId) {
+ numLocal = blockInfos.size
+ // Filter out zero-sized blocks
+ localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
+ _numBlocksToFetch += localBlocksToFetch.size
+ } else {
+ numRemote += blockInfos.size
+ // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+ // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+ // nodes, rather than blocking on reading output from one node.
+ val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+ logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+ val iterator = blockInfos.iterator
+ var curRequestSize = 0L
+ var curBlocks = new ArrayBuffer[(String, Long)]
+ while (iterator.hasNext) {
+ val (blockId, size) = iterator.next()
+ // Skip empty blocks
+ if (size > 0) {
+ curBlocks += ((blockId, size))
+ remoteBlocksToFetch += blockId
+ _numBlocksToFetch += 1
+ curRequestSize += size
+ } else if (size < 0) {
+ throw new BlockException(blockId, "Negative block size " + size)
+ }
+ if (curRequestSize >= minRequestSize) {
+ // Add this FetchRequest
+ remoteRequests += new FetchRequest(address, curBlocks)
+ curRequestSize = 0
+ curBlocks = new ArrayBuffer[(String, Long)]
+ }
+ }
+ // Add in the final request
+ if (!curBlocks.isEmpty) {
+ remoteRequests += new FetchRequest(address, curBlocks)
+ }
+ }
+ }
+ logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
+ totalBlocks + " blocks")
+ remoteRequests
+ }
+
+ protected def getLocalBlocks() {
+ // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+ // these all at once because they will just memory-map some files, so they won't consume
+ // any memory that might exceed our maxBytesInFlight
+ for (id <- localBlocksToFetch) {
+ getLocalFromDisk(id, serializer) match {
+ case Some(iter) => {
+ // Pass 0 as size since it's not in flight
+ results.put(new FetchResult(id, 0, () => iter))
+ logDebug("Got local block " + id)
+ }
+ case None => {
+ throw new BlockException(id, "Could not get block " + id + " from local machine")
+ }
+ }
+ }
+ }
+
+ override def initialize() {
+ // Split local and remote blocks.
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ fetchRequests ++= Utils.randomize(remoteRequests)
+
+ // Send out initial requests for blocks, up to our maxBytesInFlight
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+
+ val numGets = remoteRequests.size - fetchRequests.size
+ logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ // An iterator that will read fetched blocks off the queue as they arrive.
+ @volatile protected var resultsGotten = 0
+
+ override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val startFetchWait = System.currentTimeMillis()
+ val result = results.take()
+ val stopFetchWait = System.currentTimeMillis()
+ _fetchWaitTime += (stopFetchWait - startFetchWait)
+ if (!result.failed) bytesInFlight -= result.size
+ while (!fetchRequests.isEmpty &&
+ (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+ sendRequest(fetchRequests.dequeue())
+ }
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+
+ // Implementing BlockFetchTracker trait.
+ override def totalBlocks: Int = numLocal + numRemote
+ override def numLocalBlocks: Int = numLocal
+ override def numRemoteBlocks: Int = numRemote
+ override def remoteFetchTime: Long = _remoteFetchTime
+ override def fetchWaitTime: Long = _fetchWaitTime
+ override def remoteBytesRead: Long = _remoteBytesRead
+ }
+ // End of BasicBlockFetcherIterator
+
+ class NettyBlockFetcherIterator(
+ blockManager: BlockManager,
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ serializer: Serializer)
+ extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+
+ import blockManager._
+
+ val fetchRequestsSync = new LinkedBlockingQueue[FetchRequest]
+
+ private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
+ (for (i <- 0 until numCopiers) yield {
+ val copier = new Thread {
+ override def run() {
+ try {
+ while (!isInterrupted && !fetchRequestsSync.isEmpty) {
+ sendRequest(fetchRequestsSync.take())
+ }
+ } catch {
+ case x: InterruptedException => logInfo("Copier Interrupted")
+ //case _ => throw new SparkException("Exception Throw in Shuffle Copier")
+ }
+ }
+ }
+ copier.start
+ copier
+ }).toList
+ }
+
+ // keep this to interrupt the threads when necessary
+ private def stopCopiers() {
+ for (copier <- copiers) {
+ copier.interrupt()
+ }
+ }
+
+ override protected def sendRequest(req: FetchRequest) {
+
+ def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ val fetchResult = new FetchResult(blockId, blockSize,
+ () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
+ results.put(fetchResult)
+ }
+
+ logDebug("Sending request for %d blocks (%s) from %s".format(
+ req.blocks.size, Utils.bytesToString(req.size), req.address.host))
+ val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
+ val cpier = new ShuffleCopier
+ cpier.getBlocks(cmId, req.blocks, putResult)
+ logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
+ }
+
+ private var copiers: List[_ <: Thread] = null
+
+ override def initialize() {
+ // Split Local Remote Blocks and set numBlocksToFetch
+ val remoteRequests = splitLocalRemoteBlocks()
+ // Add the remote requests into our queue in a random order
+ for (request <- Utils.randomize(remoteRequests)) {
+ fetchRequestsSync.put(request)
+ }
+
+ copiers = startCopiers(System.getProperty("spark.shuffle.copier.threads", "6").toInt)
+ logInfo("Started " + fetchRequestsSync.size + " remote gets in " +
+ Utils.getUsedTimeMs(startTime))
+
+ // Get Local Blocks
+ startTime = System.currentTimeMillis
+ getLocalBlocks()
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ }
+
+ override def next(): (String, Option[Iterator[Any]]) = {
+ resultsGotten += 1
+ val result = results.take()
+ // If all the results have been retrieved, the copiers will exit automatically
+ (result.blockId, if (result.failed) None else Some(result.deserialize()))
+ }
+ }
+ // End of NettyBlockFetcherIterator
+}
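
A condensed, standalone sketch of the request-chunking policy in splitLocalRemoteBlocks above: non-empty (blockId, size) pairs are grouped into requests of at least maxBytesInFlight / 5 bytes each, so up to five fetches can proceed in parallel without exceeding the in-flight budget (the block names and sizes below are made up):

import scala.collection.mutable.ArrayBuffer

object FetchChunkingSketch {
  def main(args: Array[String]) {
    val maxBytesInFlight = 48L * 1024 * 1024          // default spark.reducer.maxMbInFlight
    val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val blocks = Seq(("shuffle_0_0_0", 10L << 20), ("shuffle_0_1_0", 4L << 20),
                     ("shuffle_0_2_0", 7L << 20), ("shuffle_0_3_0", 12L << 20))

    val requests = new ArrayBuffer[Seq[(String, Long)]]
    var cur = new ArrayBuffer[(String, Long)]
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) {
      cur += ((id, size))
      curSize += size
      if (curSize >= minRequestSize) {
        // Close out this request and start a new one
        requests += cur
        cur = new ArrayBuffer[(String, Long)]
        curSize = 0L
      }
    }
    if (!cur.isEmpty) requests += cur                 // final partial request
    requests.zipWithIndex.foreach { case (r, i) =>
      println("Request " + i + ": " + r.map(_._1).mkString(", "))
    }
  }
}
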
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
new file mode 100644
index 0000000000..3299ac98d5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -0,0 +1,1046 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{InputStream, OutputStream}
+import java.nio.{ByteBuffer, MappedByteBuffer}
+
+import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+
+import akka.actor.{ActorSystem, Cancellable, Props}
+import akka.dispatch.{Await, Future}
+import akka.util.Duration
+import akka.util.duration._
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.{Logging, SparkEnv, SparkException, Utils}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.network._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{ByteBufferInputStream, IdGenerator, MetadataCleaner, TimeStampedHashMap}
+
+import sun.nio.ch.DirectBuffer
+
+
+private[spark] class BlockManager(
+ executorId: String,
+ actorSystem: ActorSystem,
+ val master: BlockManagerMaster,
+ val defaultSerializer: Serializer,
+ maxMemory: Long)
+ extends Logging {
+
+ private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ @volatile var pending: Boolean = true
+ @volatile var size: Long = -1L
+ @volatile var initThread: Thread = null
+ @volatile var failed = false
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ this.initThread = Thread.currentThread()
+ }
+
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
+ if (initThread != Thread.currentThread() && pending) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ !failed
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ assert (pending)
+ size = sizeInBytes
+ initThread = null
+ failed = false
+ pending = false
+ synchronized {
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ assert (pending)
+ size = 0
+ initThread = null
+ failed = true
+ pending = false
+ synchronized {
+ this.notifyAll()
+ }
+ }
+ }
+
+ val shuffleBlockManager = new ShuffleBlockManager(this)
+
+ private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+
+ private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private[storage] val diskStore: DiskStore =
+ new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+
+ // If we use Netty for shuffle, start a new Netty-based shuffle sender service.
+ private val nettyPort: Int = {
+ val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
+ val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
+ if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ }
+
+ val connectionManager = new ConnectionManager(0)
+ implicit val futureExecContext = connectionManager.futureExecContext
+
+ val blockManagerId = BlockManagerId(
+ executorId, connectionManager.id.host, connectionManager.id.port, nettyPort)
+
+ // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
+ // for receiving shuffle outputs)
+ val maxBytesInFlight =
+ System.getProperty("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024
+
+ // Whether to compress broadcast variables that are stored
+ val compressBroadcast = System.getProperty("spark.broadcast.compress", "true").toBoolean
+ // Whether to compress shuffle output that are stored
+ val compressShuffle = System.getProperty("spark.shuffle.compress", "true").toBoolean
+ // Whether to compress RDD partitions that are stored serialized
+ val compressRdds = System.getProperty("spark.rdd.compress", "false").toBoolean
+
+ val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
+
+ val hostPort = Utils.localHostPort()
+
+ val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
+
+ // Pending reregistration action being executed asynchronously or null if none
+ // is pending. Accesses should synchronize on asyncReregisterLock.
+ var asyncReregisterTask: Future[Unit] = null
+ val asyncReregisterLock = new Object
+
+ private def heartBeat() {
+ if (!master.sendHeartBeat(blockManagerId)) {
+ reregister()
+ }
+ }
+
+ var heartBeatTask: Cancellable = null
+
+ val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks)
+ initialize()
+
+ // The compression codec to use. Note that the "lazy" val is necessary because we want to delay
+ // the initialization of the compression codec until it is first used. The reason is that a Spark
+ // program could be using a user-defined codec in a third party jar, which is loaded in
+ // Executor.updateDependencies. When the BlockManager is initialized, user-level jars haven't been
+ // loaded yet.
+ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec()
+
+ /**
+ * Construct a BlockManager with a memory limit set based on system properties.
+ */
+ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster,
+ serializer: Serializer) = {
+ this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties)
+ }
+
+ /**
+ * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
+ * BlockManagerWorker actor.
+ */
+ private def initialize() {
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ BlockManagerWorker.startBlockManagerWorker(this)
+ if (!BlockManager.getDisableHeartBeatsForTesting) {
+ heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) {
+ heartBeat()
+ }
+ }
+ }
+
+ /**
+ * Report all blocks to the BlockManagerMaster again. This may be necessary if we are dropped
+ * by the master and come back, or if we become capable of recovering blocks on disk after
+ * an executor crash.
+ *
+ * This function deliberately fails silently if the master returns false (indicating that
+ * the slave needs to reregister). The error condition will be detected again by the next
+ * heart beat attempt or new block registration and another try to reregister all blocks
+ * will be made then.
+ */
+ private def reportAllBlocks() {
+ logInfo("Reporting " + blockInfo.size + " blocks to the master.")
+ for ((blockId, info) <- blockInfo) {
+ if (!tryToReportBlockStatus(blockId, info)) {
+ logError("Failed to report " + blockId + " to master; giving up.")
+ return
+ }
+ }
+ }
+
+ /**
+ * Reregister with the master and report all blocks to it. This will be called by the heart beat
+ * thread if our heartbeat to the block manager indicates that we were not registered.
+ *
+ * Note that this method must be called without any BlockInfo locks held.
+ */
+ def reregister() {
+ // TODO: We might need to rate limit reregistering.
+ logInfo("BlockManager reregistering with master")
+ master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+ reportAllBlocks()
+ }
+
+ /**
+ * Reregister with the master sometime soon.
+ */
+ def asyncReregister() {
+ asyncReregisterLock.synchronized {
+ if (asyncReregisterTask == null) {
+ asyncReregisterTask = Future[Unit] {
+ reregister()
+ asyncReregisterLock.synchronized {
+ asyncReregisterTask = null
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
+ */
+ def waitForAsyncReregister() {
+ val task = asyncReregisterTask
+ if (task != null) {
+ Await.ready(task, Duration.Inf)
+ }
+ }
+
+ /**
+ * Get storage level of local block. If no info exists for the block, then returns null.
+ */
+ def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+
+ /**
+ * Tell the master about the current storage status of a block. This will send a block update
+ * message reflecting the current status, *not* the desired storage level in its block info.
+ * For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
+ *
+ * droppedMemorySize exists to account for the case when a block is dropped from memory to disk
+ * (so it is still valid). This ensures that the update in the master will compensate for the
+ * increase in memory on the slave.
+ */
+ def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
+ if (needReregister) {
+ logInfo("Got told to reregister updating block " + blockId)
+ // Reregistering will report our new block for free.
+ asyncReregister()
+ }
+ logDebug("Told master about block " + blockId)
+ }
+
+ /**
+ * Actually send an UpdateBlockInfo message. Returns the master's response,
+ * which will be true if the block was successfully recorded and false if
+ * the slave needs to re-register.
+ */
+ private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+ val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
+ info.level match {
+ case null =>
+ (StorageLevel.NONE, 0L, 0L, false)
+ case level =>
+ val inMem = level.useMemory && memoryStore.contains(blockId)
+ val onDisk = level.useDisk && diskStore.contains(blockId)
+ val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
+ val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
+ val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
+ (storageLevel, memSize, diskSize, info.tellMaster)
+ }
+ }
+
+ if (tellMaster) {
+ master.updateBlockInfo(blockManagerId, blockId, curLevel, inMemSize, onDiskSize)
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Get locations of an array of blocks.
+ */
+ def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+ val startTimeMs = System.currentTimeMillis
+ val locations = master.getLocations(blockIds).toArray
+ logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
+ locations
+ }
+
+ /**
+ * A short-circuited method to get blocks directly from disk. This is used for getting
+ * shuffle blocks. It is safe to do so without a lock on block info since disk store
+ * never deletes (recent) items.
+ */
+ def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ diskStore.getValues(blockId, serializer).orElse(
+ sys.error("Block " + blockId + " not found on disk, though it should be"))
+ }
+
+ /**
+ * Get block from local block manager.
+ */
+ def getLocal(blockId: String): Option[Iterator[Any]] = {
+ logDebug("Getting local block " + blockId)
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) {
+ info.synchronized {
+
+ // If another thread is writing the block, wait for it to become ready.
+ if (!info.waitForReady()) {
+ // If we get here, the block write failed.
+ logWarning("Block " + blockId + " was marked as failure.")
+ return None
+ }
+
+ val level = info.level
+ logDebug("Level for block " + blockId + " is " + level)
+
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getValues(blockId) match {
+ case Some(iterator) =>
+ return Some(iterator)
+ case None =>
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+
+ // Look for block on disk, potentially loading it back into memory if required
+ if (level.useDisk) {
+ logDebug("Getting block " + blockId + " from disk")
+ if (level.useMemory && level.deserialized) {
+ diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ // Put the block back in memory before returning it
+ // TODO: Consider creating a putValues that also takes in an iterator?
+ val elements = new ArrayBuffer[Any]
+ elements ++= iterator
+ memoryStore.putValues(blockId, elements, level, true).data match {
+ case Left(iterator2) =>
+ return Some(iterator2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ } else if (level.useMemory && !level.deserialized) {
+ // Read it as a byte buffer into memory first, then return it
+ diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ // Put a copy of the block back in memory before returning it. Note that we can't
+ // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
+ // The use of rewind assumes this.
+ assert (0 == bytes.position())
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ return Some(dataDeserialize(blockId, bytes))
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ } else {
+ diskStore.getValues(blockId) match {
+ case Some(iterator) =>
+ return Some(iterator)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+ }
+ }
+ } else {
+ logDebug("Block " + blockId + " not registered locally")
+ }
+ return None
+ }
+
+ /**
+ * Get block from the local block manager as serialized bytes.
+ */
+ def getLocalBytes(blockId: String): Option[ByteBuffer] = {
+ // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
+ logDebug("Getting local block " + blockId + " as bytes")
+
+ // As an optimization for map output fetches, if the block is for a shuffle, return it
+ // without acquiring a lock; the disk store never deletes (recent) items so this should work
+ if (ShuffleBlockManager.isShuffle(blockId)) {
+ return diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ Some(bytes)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) {
+ info.synchronized {
+
+        // If another thread is writing the block, wait for it to become ready.
+        if (!info.waitForReady()) {
+          // If we get here, the block write failed.
+          logWarning("Block " + blockId + " was marked as a failure.")
+ return None
+ }
+
+ val level = info.level
+ logDebug("Level for block " + blockId + " is " + level)
+
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ return Some(bytes)
+ case None =>
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+
+ // Look for block on disk
+ if (level.useDisk) {
+ // Read it as a byte buffer into memory first, then return it
+ diskStore.getBytes(blockId) match {
+ case Some(bytes) =>
+ assert (0 == bytes.position())
+ if (level.useMemory) {
+ if (level.deserialized) {
+ memoryStore.putBytes(blockId, bytes, level)
+ } else {
+ // The memory store will hang onto the ByteBuffer, so give it a copy instead of
+ // the memory-mapped file buffer we got from the disk store
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ }
+ }
+ bytes.rewind()
+ return Some(bytes)
+ case None =>
+ throw new Exception("Block " + blockId + " not found on disk, though it should be")
+ }
+ }
+ }
+ } else {
+ logDebug("Block " + blockId + " not registered locally")
+ }
+ return None
+ }
+
+ /**
+ * Get block from remote block managers.
+ */
+ def getRemote(blockId: String): Option[Iterator[Any]] = {
+ if (blockId == null) {
+ throw new IllegalArgumentException("Block Id is null")
+ }
+ logDebug("Getting remote block " + blockId)
+ // Get locations of block
+ val locations = master.getLocations(blockId)
+
+ // Get block from remote locations
+ for (loc <- locations) {
+ logDebug("Getting remote block " + blockId + " from " + loc)
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+ if (data != null) {
+ return Some(dataDeserialize(blockId, data))
+ }
+ logDebug("The value of block " + blockId + " is null")
+ }
+ logDebug("Block " + blockId + " not found")
+ return None
+ }
+
+ /**
+ * Get a block from the block manager (either local or remote).
+ */
+ def get(blockId: String): Option[Iterator[Any]] = {
+ getLocal(blockId).orElse(getRemote(blockId))
+ }
+
+ /**
+ * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
+ * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
+ * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
+ * so that we can control the maxMegabytesInFlight for the fetch.
+ */
+ def getMultiple(
+ blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+ : BlockFetcherIterator = {
+
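+    // Use the Netty-based fetcher when spark.shuffle.use.netty is enabled; otherwise use the
+    // basic fetcher.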
+ val iter =
+ if (System.getProperty("spark.shuffle.use.netty", "false").toBoolean) {
+ new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+ } else {
+ new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+ }
+
+ iter.initialize()
+ iter
+ }
+
+ def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ : Long = {
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ put(blockId, elements, level, tellMaster)
+ }
+
+ /**
+   * A short-circuited method to get a block writer that can write data directly to disk.
+ * This is currently used for writing shuffle files out. Callers should handle error
+ * cases.
+ */
+ def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
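+    // When the writer is closed, register the block and mark it ready at its final size so that
+    // readers can find it through blockInfo.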
+ writer.registerCloseEventHandler(() => {
+ val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ blockInfo.put(blockId, myInfo)
+ myInfo.markReady(writer.size())
+ })
+ writer
+ }
+
+ /**
+ * Put a new block of values to the block manager. Returns its (estimated) size in bytes.
+ */
+ def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ tellMaster: Boolean = true) : Long = {
+
+ if (blockId == null) {
+ throw new IllegalArgumentException("Block Id is null")
+ }
+ if (values == null) {
+ throw new IllegalArgumentException("Values is null")
+ }
+ if (level == null || !level.isValid) {
+ throw new IllegalArgumentException("Storage level is null or invalid")
+ }
+
+ // Remember the block's storage level so that we can correctly drop it to disk if it needs
+ // to be dropped right after it got put into memory. Note, however, that other threads will
+ // not be able to get() this block until we call markReady on its BlockInfo.
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+      // Do this atomically!
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return oldBlockOpt.get.size
+ }
+
+        // TODO: The block info exists, but the previous attempt to load it failed.
+        // What should we do now? Retry?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
+
+ val startTimeMs = System.currentTimeMillis
+
+ // If we need to replicate the data, we'll want access to the values, but because our
+ // put will read the whole iterator, there will be no values left. For the case where
+ // the put serializes data, we'll remember the bytes, above; but for the case where it
+ // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ var valuesAfterPut: Iterator[Any] = null
+
+ // Ditto for the bytes after the put
+ var bytesAfterPut: ByteBuffer = null
+
+ // Size of the block in bytes (to return to caller)
+ var size = 0L
+
+ myInfo.synchronized {
+ logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ + " to get into synchronized block")
+
+ var marked = false
+ try {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will later
+ // drop it to disk if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
+
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ marked = true
+ myInfo.markReady(size)
+ if (tellMaster) {
+ reportBlockStatus(blockId, myInfo)
+ }
+ } finally {
+ // If we failed at putting the block to memory/disk, notify other possible readers
+ // that it has failed, and then remove it from the block info map.
+        if (!marked) {
+ // Note that the remove must happen before markFailure otherwise another thread
+ // could've inserted a new BlockInfo before we remove it.
+ blockInfo.remove(blockId)
+ myInfo.markFailure()
+ logWarning("Putting block " + blockId + " failed")
+ }
+ }
+ }
+ logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
+
+ // Replicate block if required
+ if (level.replication > 1) {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
+ }
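+    // dispose() unmaps the serialized copy early if it is a memory-mapped buffer,
+    // rather than waiting for GC to reclaim it.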
+ BlockManager.dispose(bytesAfterPut)
+
+ return size
+ }
+
+
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(
+ blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+
+ if (blockId == null) {
+ throw new IllegalArgumentException("Block Id is null")
+ }
+ if (bytes == null) {
+ throw new IllegalArgumentException("Bytes is null")
+ }
+ if (level == null || !level.isValid) {
+ throw new IllegalArgumentException("Storage level is null or invalid")
+ }
+
+ // Remember the block's storage level so that we can correctly drop it to disk if it needs
+ // to be dropped right after it got put into memory. Note, however, that other threads will
+ // not be able to get() this block until we call markReady on its BlockInfo.
+ val myInfo = {
+ val tinfo = new BlockInfo(level, tellMaster)
+      // Do this atomically!
+ val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
+
+ if (oldBlockOpt.isDefined) {
+ if (oldBlockOpt.get.waitForReady()) {
+ logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
+ return
+ }
+
+        // TODO: The block info exists, but the previous attempt to load it failed.
+        // What should we do now? Retry?
+ oldBlockOpt.get
+ } else {
+ tinfo
+ }
+ }
+
+ val startTimeMs = System.currentTimeMillis
+
+    // Initiate replication before storing the block locally. This is faster because the data
+    // is already serialized and ready to send.
+ val replicationFuture = if (level.replication > 1) {
+ val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
+ myInfo.synchronized {
+ logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ + " to get into synchronized block")
+
+ var marked = false
+ try {
+ if (level.useMemory) {
+ // Store it only in memory at first, even if useDisk is also set to true
+ bytes.rewind()
+ memoryStore.putBytes(blockId, bytes, level)
+ } else {
+ bytes.rewind()
+ diskStore.putBytes(blockId, bytes, level)
+ }
+
+ // assert (0 == bytes.position(), "" + bytes)
+
+ // Now that the block is in either the memory or disk store, let other threads read it,
+ // and tell the master about it.
+ marked = true
+ myInfo.markReady(bytes.limit)
+ if (tellMaster) {
+ reportBlockStatus(blockId, myInfo)
+ }
+ } finally {
+ // If we failed at putting the block to memory/disk, notify other possible readers
+ // that it has failed, and then remove it from the block info map.
+        if (!marked) {
+ // Note that the remove must happen before markFailure otherwise another thread
+ // could've inserted a new BlockInfo before we remove it.
+ blockInfo.remove(blockId)
+ myInfo.markFailure()
+ logWarning("Putting block " + blockId + " failed")
+ }
+ }
+ }
+
+ // If replication had started, then wait for it to finish
+ if (level.replication > 1) {
+ Await.ready(replicationFuture, Duration.Inf)
+ }
+
+ if (level.replication > 1) {
+ logDebug("PutBytes for block " + blockId + " with replication took " +
+ Utils.getUsedTimeMs(startTimeMs))
+ } else {
+ logDebug("PutBytes for block " + blockId + " without replication took " +
+ Utils.getUsedTimeMs(startTimeMs))
+ }
+ }
+
+ /**
+ * Replicate block to another node.
+ */
+ var cachedPeers: Seq[BlockManagerId] = null
+ private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
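+    // Replicas are stored with replication 1 so that peers do not replicate the block further.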
+ if (cachedPeers == null) {
+ cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
+ }
+ for (peer: BlockManagerId <- cachedPeers) {
+ val start = System.nanoTime
+ data.rewind()
+ logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ + data.limit() + " Bytes. To node: " + peer)
+ if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
+ new ConnectionManagerId(peer.host, peer.port))) {
+ logError("Failed to call syncPutBlock to " + peer)
+ }
+ logDebug("Replicated BlockId " + blockId + " once used " +
+        (System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
+ data.limit() + " bytes.")
+ }
+ }
+
+ /**
+ * Read a block consisting of a single object.
+ */
+ def getSingle(blockId: String): Option[Any] = {
+ get(blockId).map(_.next())
+ }
+
+ /**
+ * Write a block consisting of a single object.
+ */
+ def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ put(blockId, Iterator(value), level, tellMaster)
+ }
+
+ /**
+ * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
+ * store reaches its limit and needs to free up space.
+ */
+ def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ logInfo("Dropping block " + blockId + " from memory")
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) {
+ info.synchronized {
+        // Is this required? As of now, this will only be invoked for blocks that are ready,
+        // but in case this changes in the future, we keep the check for consistency's sake.
+        if (!info.waitForReady()) {
+          // If we get here, the block write failed.
+          logWarning("Block " + blockId + " was marked as a failure. Nothing to drop")
+ return
+ }
+
+ val level = info.level
+ if (level.useDisk && !diskStore.contains(blockId)) {
+ logInfo("Writing block " + blockId + " to disk")
+ data match {
+ case Left(elements) =>
+ diskStore.putValues(blockId, elements, level, false)
+ case Right(bytes) =>
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ }
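+        // Record the memory the block occupied so the master can account for the freed space.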
+        val droppedMemorySize =
+          if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+ val blockWasRemoved = memoryStore.remove(blockId)
+ if (!blockWasRemoved) {
+ logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
+ }
+ if (info.tellMaster) {
+ reportBlockStatus(blockId, info, droppedMemorySize)
+ }
+ if (!level.useDisk) {
+ // The block is completely gone from this node; forget it so we can put() it again later.
+ blockInfo.remove(blockId)
+ }
+ }
+ } else {
+ // The block has already been dropped
+ }
+ }
+
+ /**
+ * Remove all blocks belonging to the given RDD.
+ * @return The number of blocks removed.
+ */
+ def removeRdd(rddId: Int): Int = {
+ // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
+ // from RDD.id to blocks.
+ logInfo("Removing RDD " + rddId)
+ val rddPrefix = "rdd_" + rddId + "_"
+ val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ blocksToRemove.size
+ }
+
+ /**
+ * Remove a block from both memory and disk.
+ */
+ def removeBlock(blockId: String, tellMaster: Boolean = true) {
+ logInfo("Removing block " + blockId)
+ val info = blockInfo.get(blockId).orNull
+ if (info != null) info.synchronized {
+ // Removals are idempotent in disk store and memory store. At worst, we get a warning.
+ val removedFromMemory = memoryStore.remove(blockId)
+ val removedFromDisk = diskStore.remove(blockId)
+ if (!removedFromMemory && !removedFromDisk) {
+ logWarning("Block " + blockId + " could not be removed as it was not found in either " +
+ "the disk or memory store")
+ }
+ blockInfo.remove(blockId)
+ if (tellMaster && info.tellMaster) {
+ reportBlockStatus(blockId, info)
+ }
+ } else {
+ // The block has already been removed; do nothing.
+ logWarning("Asked to remove block " + blockId + ", which does not exist")
+ }
+ }
+
+ def dropOldBlocks(cleanupTime: Long) {
+ logInfo("Dropping blocks older than " + cleanupTime)
+ val iterator = blockInfo.internalMap.entrySet().iterator()
+ while (iterator.hasNext) {
+ val entry = iterator.next()
+ val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ if (time < cleanupTime) {
+ info.synchronized {
+ val level = info.level
+ if (level.useMemory) {
+ memoryStore.remove(id)
+ }
+ if (level.useDisk) {
+ diskStore.remove(id)
+ }
+ iterator.remove()
+ logInfo("Dropped block " + id)
+ }
+ reportBlockStatus(id, info)
+ }
+ }
+ }
+
+ def shouldCompress(blockId: String): Boolean = {
+ if (ShuffleBlockManager.isShuffle(blockId)) {
+ compressShuffle
+ } else if (blockId.startsWith("broadcast_")) {
+ compressBroadcast
+ } else if (blockId.startsWith("rdd_")) {
+ compressRdds
+ } else {
+ false // Won't happen in a real cluster, but it can in tests
+ }
+ }
+
+ /**
+ * Wrap an output stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
+ }
+
+ /**
+ * Wrap an input stream for compression if block compression is enabled for its block type
+ */
+ def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
+ }
+
+ def dataSerialize(
+ blockId: String,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer): ByteBuffer = {
+ val byteStream = new FastByteArrayOutputStream(4096)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ byteStream.trim()
+ ByteBuffer.wrap(byteStream.array)
+ }
+
+ /**
+ * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of
+ * the iterator is reached.
+ */
+ def dataDeserialize(
+ blockId: String,
+ bytes: ByteBuffer,
+ serializer: Serializer = defaultSerializer): Iterator[Any] = {
+ bytes.rewind()
+ val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true))
+ serializer.newInstance().deserializeStream(stream).asIterator
+ }
+
+ def stop() {
+ if (heartBeatTask != null) {
+ heartBeatTask.cancel()
+ }
+ connectionManager.stop()
+ actorSystem.stop(slaveActor)
+ blockInfo.clear()
+ memoryStore.clear()
+ diskStore.clear()
+ metadataCleaner.cancel()
+ logInfo("BlockManager stopped")
+ }
+}
+
+
+private[spark] object BlockManager extends Logging {
+
+ val ID_GENERATOR = new IdGenerator
+
+ def getMaxMemoryFromSystemProperties: Long = {
+ val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble
+ (Runtime.getRuntime.maxMemory * memoryFraction).toLong
+ }
+
+ def getHeartBeatFrequencyFromSystemProperties: Long =
+ System.getProperty("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4
+
+ def getDisableHeartBeatsForTesting: Boolean =
+ System.getProperty("spark.test.disableBlockManagerHeartBeat", "false").toBoolean
+
+ /**
+ * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that
+ * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+ * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+ * unfortunately no standard API to do this.
+ */
+ def dispose(buffer: ByteBuffer) {
+ if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) {
+ logTrace("Unmapping " + buffer)
+ if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) {
+ buffer.asInstanceOf[DirectBuffer].cleaner().clean()
+ }
+ }
+ }
+
+ def blockIdsToBlockManagers(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[BlockManagerId]] =
+ {
+    // In tests, env may be null while blockManagerMaster is provided directly.
+ assert (env != null || blockManagerMaster != null)
+ val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) {
+ env.blockManager.getLocationBlockIds(blockIds)
+ } else {
+ blockManagerMaster.getLocations(blockIds)
+ }
+
+ val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+ for (i <- 0 until blockIds.length) {
+ blockManagers(blockIds(i)) = blockLocations(i)
+ }
+ blockManagers.toMap
+ }
+
+ def blockIdsToExecutorIds(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[String]] =
+ {
+ blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
+ }
+
+ def blockIdsToHosts(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[String]] =
+ {
+ blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
new file mode 100644
index 0000000000..a22a80decc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.Utils
+
+/**
+ * This class represents a unique identifier for a BlockManager.
+ * Both constructors of this class are made private to ensure that
+ * BlockManagerId objects can only be created using the apply method in
+ * the companion object. This allows de-duplication of ID objects.
+ * Also, the constructor parameters are private so that they cannot
+ * be modified from outside this class.
+ */
+private[spark] class BlockManagerId private (
+ private var executorId_ : String,
+ private var host_ : String,
+ private var port_ : Int,
+ private var nettyPort_ : Int
+ ) extends Externalizable {
+
+ private def this() = this(null, null, 0, 0) // For deserialization only
+
+ def executorId: String = executorId_
+
+  if (host_ != null) {
+ Utils.checkHost(host_, "Expected hostname")
+ assert (port_ > 0)
+ }
+
+ def hostPort: String = {
+ // DEBUG code
+ Utils.checkHost(host)
+ assert (port > 0)
+
+ host + ":" + port
+ }
+
+ def host: String = host_
+
+ def port: Int = port_
+
+ def nettyPort: Int = nettyPort_
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeUTF(executorId_)
+ out.writeUTF(host_)
+ out.writeInt(port_)
+ out.writeInt(nettyPort_)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ executorId_ = in.readUTF()
+ host_ = in.readUTF()
+ port_ = in.readInt()
+ nettyPort_ = in.readInt()
+ }
+
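+  // Ensure deserialized instances resolve to the canonical cached BlockManagerId.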
+ @throws(classOf[IOException])
+ private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
+
+ override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort)
+
+ override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort
+
+ override def equals(that: Any) = that match {
+ case id: BlockManagerId =>
+ executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort
+ case _ =>
+ false
+ }
+}
+
+
+private[spark] object BlockManagerId {
+
+ /**
+   * Returns a [[org.apache.spark.storage.BlockManagerId]] for the given configuration.
+ *
+ * @param execId ID of the executor.
+ * @param host Host name of the block manager.
+ * @param port Port of the block manager.
+ * @param nettyPort Optional port for the Netty-based shuffle sender.
+ * @return A new [[org.apache.spark.storage.BlockManagerId]].
+ */
+ def apply(execId: String, host: String, port: Int, nettyPort: Int) =
+ getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort))
+
+ def apply(in: ObjectInput) = {
+ val obj = new BlockManagerId()
+ obj.readExternal(in)
+ getCachedBlockManagerId(obj)
+ }
+
+ val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
+
+ def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
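+    // putIfAbsent followed by get always returns the one canonical instance,
+    // even under concurrent calls.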
+ blockManagerIdCache.putIfAbsent(id, id)
+ blockManagerIdCache.get(id)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
new file mode 100644
index 0000000000..cf463d6ffc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import akka.actor.ActorRef
+import akka.dispatch.{Await, Future}
+import akka.pattern.ask
+import akka.util.Duration
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.storage.BlockManagerMessages._
+
+
+private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
+
+ val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
+ val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
+
+ val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
+
+ val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
+ /** Remove a dead executor from the driver actor. This is only called on the driver side. */
+ def removeExecutor(execId: String) {
+ tell(RemoveExecutor(execId))
+ logInfo("Removed " + execId + " successfully in removeExecutor")
+ }
+
+ /**
+ * Send the driver actor a heart beat from the slave. Returns true if everything works out,
+ * false if the driver does not know about the given block manager, which means the block
+ * manager should re-register.
+ */
+ def sendHeartBeat(blockManagerId: BlockManagerId): Boolean = {
+ askDriverWithReply[Boolean](HeartBeat(blockManagerId))
+ }
+
+ /** Register the BlockManager's id with the driver. */
+ def registerBlockManager(
+ blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ logInfo("Trying to register BlockManager")
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
+ logInfo("Registered BlockManager")
+ }
+
+ def updateBlockInfo(
+ blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long): Boolean = {
+ val res = askDriverWithReply[Boolean](
+ UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize))
+ logInfo("Updated info of block " + blockId)
+ res
+ }
+
+ /** Get locations of the blockId from the driver */
+ def getLocations(blockId: String): Seq[BlockManagerId] = {
+ askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
+ }
+
+ /** Get locations of multiple blockIds from the driver */
+ def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
+ }
+
+ /** Get ids of other nodes in the cluster from the driver */
+ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
+ val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
+ if (result.length != numPeers) {
+ throw new SparkException(
+ "Error getting peers, only got " + result.size + " instead of " + numPeers)
+ }
+ result
+ }
+
+ /**
+ * Remove a block from the slaves that have it. This can only be used to remove
+ * blocks that the driver knows about.
+ */
+ def removeBlock(blockId: String) {
+ askDriverWithReply(RemoveBlock(blockId))
+ }
+
+ /**
+ * Remove all blocks belonging to the given RDD.
+ */
+ def removeRdd(rddId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
+ future onFailure {
+ case e: Throwable => logError("Failed to remove RDD " + rddId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /**
+ * Return the memory status for each block manager, in the form of a map from
+ * the block manager's id to two long values. The first value is the maximum
+ * amount of memory allocated for the block manager, while the second is the
+ * amount of remaining memory.
+ */
+ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
+ }
+
+ def getStorageStatus: Array[StorageStatus] = {
+ askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
+ }
+
+ /** Stop the driver actor, called only on the Spark driver node */
+ def stop() {
+ if (driverActor != null) {
+ tell(StopBlockManagerMaster)
+ driverActor = null
+ logInfo("BlockManagerMaster stopped")
+ }
+ }
+
+ /** Send a one-way message to the master actor, to which we expect it to reply with true. */
+ private def tell(message: Any) {
+ if (!askDriverWithReply[Boolean](message)) {
+ throw new SparkException("BlockManagerMasterActor returned false, expected true.")
+ }
+ }
+
+ /**
+ * Send a message to the driver actor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ private def askDriverWithReply[T](message: Any): T = {
+ // TODO: Consider removing multiple attempts
+ if (driverActor == null) {
+ throw new SparkException("Error sending message to BlockManager as driverActor is null " +
+ "[message = " + message + "]")
+ }
+ var attempts = 0
+ var lastException: Exception = null
+ while (attempts < AKKA_RETRY_ATTEMPTS) {
+ attempts += 1
+ try {
+ val future = driverActor.ask(message)(timeout)
+ val result = Await.result(future, timeout)
+ if (result == null) {
+ throw new SparkException("BlockManagerMaster returned null")
+ }
+ return result.asInstanceOf[T]
+ } catch {
+ case ie: InterruptedException => throw ie
+ case e: Exception =>
+ lastException = e
+          logWarning("Error sending message to BlockManagerMaster (attempt " + attempts + ")", e)
+ }
+ Thread.sleep(AKKA_RETRY_INTERVAL_MS)
+ }
+
+ throw new SparkException(
+ "Error sending message to BlockManagerMaster [message = " + message + "]", lastException)
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
new file mode 100644
index 0000000000..baa4a1da50
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.{HashMap => JHashMap}
+
+import scala.collection.mutable
+import scala.collection.JavaConversions._
+
+import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.dispatch.Future
+import akka.pattern.ask
+import akka.util.Duration
+import akka.util.duration._
+
+import org.apache.spark.{Logging, Utils, SparkException}
+import org.apache.spark.storage.BlockManagerMessages._
+
+
+/**
+ * BlockManagerMasterActor is an actor on the master node to track statuses of
+ * all slaves' block managers.
+ */
+private[spark]
+class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
+
+ // Mapping from block manager id to the block manager's information.
+ private val blockManagerInfo =
+ new mutable.HashMap[BlockManagerId, BlockManagerMasterActor.BlockManagerInfo]
+
+ // Mapping from executor ID to block manager ID.
+ private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
+
+ // Mapping from block id to the set of block managers that have the block.
+ private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+
+ val akkaTimeout = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
+ initLogging()
+
+ val slaveTimeout = System.getProperty("spark.storage.blockManagerSlaveTimeoutMs",
+ "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong
+
+ val checkTimeoutInterval = System.getProperty("spark.storage.blockManagerTimeoutIntervalMs",
+ "60000").toLong
+
+ var timeoutCheckingTask: Cancellable = null
+
+ override def preStart() {
+ if (!BlockManager.getDisableHeartBeatsForTesting) {
+ timeoutCheckingTask = context.system.scheduler.schedule(
+ 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
+ }
+ super.preStart()
+ }
+
+ def receive = {
+ case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
+ register(blockManagerId, maxMemSize, slaveActor)
+ sender ! true
+
+ case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) =>
+ // TODO: Ideally we want to handle all the message replies in receive instead of in the
+ // individual private methods.
+ updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size)
+
+ case GetLocations(blockId) =>
+ sender ! getLocations(blockId)
+
+ case GetLocationsMultipleBlockIds(blockIds) =>
+ sender ! getLocationsMultipleBlockIds(blockIds)
+
+ case GetPeers(blockManagerId, size) =>
+ sender ! getPeers(blockManagerId, size)
+
+ case GetMemoryStatus =>
+ sender ! memoryStatus
+
+ case GetStorageStatus =>
+ sender ! storageStatus
+
+ case RemoveRdd(rddId) =>
+ sender ! removeRdd(rddId)
+
+ case RemoveBlock(blockId) =>
+ removeBlockFromWorkers(blockId)
+ sender ! true
+
+ case RemoveExecutor(execId) =>
+ removeExecutor(execId)
+ sender ! true
+
+ case StopBlockManagerMaster =>
+ logInfo("Stopping BlockManagerMaster")
+ sender ! true
+ if (timeoutCheckingTask != null) {
+ timeoutCheckingTask.cancel()
+ }
+ context.stop(self)
+
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+
+ case HeartBeat(blockManagerId) =>
+ sender ! heartBeat(blockManagerId)
+
+ case other =>
+ logWarning("Got unknown message: " + other)
+ }
+
+ private def removeRdd(rddId: Int): Future[Seq[Int]] = {
+ // First remove the metadata for the given RDD, and then asynchronously remove the blocks
+ // from the slaves.
+
+ val prefix = "rdd_" + rddId + "_"
+ // Find all blocks for the given RDD, remove the block from both blockLocations and
+ // the blockManagerInfo that is tracking the blocks.
+ val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ blocks.foreach { blockId =>
+ val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
+ bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
+ blockLocations.remove(blockId)
+ }
+
+ // Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
+ // The dispatcher is used as an implicit argument into the Future sequence construction.
+ import context.dispatcher
+ val removeMsg = RemoveRdd(rddId)
+ Future.sequence(blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq)
+ }
+
+ private def removeBlockManager(blockManagerId: BlockManagerId) {
+ val info = blockManagerInfo(blockManagerId)
+
+ // Remove the block manager from blockManagerIdByExecutor.
+ blockManagerIdByExecutor -= blockManagerId.executorId
+
+ // Remove it from blockManagerInfo and remove all the blocks.
+ blockManagerInfo.remove(blockManagerId)
+ val iterator = info.blocks.keySet.iterator
+ while (iterator.hasNext) {
+ val blockId = iterator.next
+ val locations = blockLocations.get(blockId)
+ locations -= blockManagerId
+ if (locations.size == 0) {
+        blockLocations.remove(blockId)
+ }
+ }
+ }
+
+ private def expireDeadHosts() {
+ logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
+ val now = System.currentTimeMillis()
+ val minSeenTime = now - slaveTimeout
+ val toRemove = new mutable.HashSet[BlockManagerId]
+ for (info <- blockManagerInfo.values) {
+ if (info.lastSeenMs < minSeenTime) {
+ logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: " +
+ (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
+ toRemove += info.blockManagerId
+ }
+ }
+ toRemove.foreach(removeBlockManager)
+ }
+
+ private def removeExecutor(execId: String) {
+ logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
+ blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
+ }
+
+ private def heartBeat(blockManagerId: BlockManagerId): Boolean = {
+ if (!blockManagerInfo.contains(blockManagerId)) {
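+      // The driver's block manager is intentionally never registered (except in local mode), so
+      // accept its heart beats; any other unknown block manager must re-register.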
+ blockManagerId.executorId == "<driver>" && !isLocal
+ } else {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ true
+ }
+ }
+
+ // Remove a block from the slaves that have it. This can only be used to remove
+ // blocks that the master knows about.
+ private def removeBlockFromWorkers(blockId: String) {
+ val locations = blockLocations.get(blockId)
+ if (locations != null) {
+ locations.foreach { blockManagerId: BlockManagerId =>
+ val blockManager = blockManagerInfo.get(blockManagerId)
+ if (blockManager.isDefined) {
+ // Remove the block from the slave's BlockManager.
+ // Doesn't actually wait for a confirmation and the message might get lost.
+ // If message loss becomes frequent, we should add retry logic here.
+ blockManager.get.slaveActor ! RemoveBlock(blockId)
+ }
+ }
+ }
+ }
+
+ // Return a map from the block manager id to max memory and remaining memory.
+ private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
+ (blockManagerId, (info.maxMem, info.remainingMem))
+ }.toMap
+ }
+
+ private def storageStatus: Array[StorageStatus] = {
+ blockManagerInfo.map { case(blockManagerId, info) =>
+ import collection.JavaConverters._
+ StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala.toMap)
+ }.toArray
+ }
+
+ private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+ if (id.executorId == "<driver>" && !isLocal) {
+ // Got a register message from the master node; don't register it
+ } else if (!blockManagerInfo.contains(id)) {
+ blockManagerIdByExecutor.get(id.executorId) match {
+ case Some(manager) =>
+ // A block manager of the same executor already exists.
+ // This should never happen. Let's just quit.
+ logError("Got two different block manager registrations on " + id.executorId)
+ System.exit(1)
+ case None =>
+ blockManagerIdByExecutor(id.executorId) = id
+ }
+ blockManagerInfo(id) = new BlockManagerMasterActor.BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
+ }
+ }
+
+ private def updateBlockInfo(
+ blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long) {
+
+ if (!blockManagerInfo.contains(blockManagerId)) {
+ if (blockManagerId.executorId == "<driver>" && !isLocal) {
+ // We intentionally do not register the master (except in local mode),
+ // so we should not indicate failure.
+ sender ! true
+ } else {
+ sender ! false
+ }
+ return
+ }
+
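+    // An UpdateBlockInfo with no block id just refreshes the block manager's last-seen time.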
+ if (blockId == null) {
+ blockManagerInfo(blockManagerId).updateLastSeenMs()
+ sender ! true
+ return
+ }
+
+ blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize)
+
+ var locations: mutable.HashSet[BlockManagerId] = null
+ if (blockLocations.containsKey(blockId)) {
+ locations = blockLocations.get(blockId)
+ } else {
+ locations = new mutable.HashSet[BlockManagerId]
+ blockLocations.put(blockId, locations)
+ }
+
+ if (storageLevel.isValid) {
+ locations.add(blockManagerId)
+ } else {
+ locations.remove(blockManagerId)
+ }
+
+ // Remove the block from master tracking if it has been removed on all slaves.
+ if (locations.size == 0) {
+ blockLocations.remove(blockId)
+ }
+ sender ! true
+ }
+
+ private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
+ }
+
+ private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map(blockId => getLocations(blockId))
+ }
+
+ private def getPeers(blockManagerId: BlockManagerId, size: Int): Seq[BlockManagerId] = {
+ val peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray
+
+ val selfIndex = peers.indexOf(blockManagerId)
+ if (selfIndex == -1) {
+ throw new SparkException("Self index for " + blockManagerId + " not found")
+ }
+
+ // Note that this logic will select the same node multiple times if there aren't enough peers
+ Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq
+ }
+}
+
+
+private[spark]
+object BlockManagerMasterActor {
+
+ case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
+
+ class BlockManagerInfo(
+ val blockManagerId: BlockManagerId,
+ timeMs: Long,
+ val maxMem: Long,
+ val slaveActor: ActorRef)
+ extends Logging {
+
+ private var _lastSeenMs: Long = timeMs
+ private var _remainingMem: Long = maxMem
+
+ // Mapping from block id to its status.
+ private val _blocks = new JHashMap[String, BlockStatus]
+
+ logInfo("Registering block manager %s with %s RAM".format(
+ blockManagerId.hostPort, Utils.bytesToString(maxMem)))
+
+ def updateLastSeenMs() {
+ _lastSeenMs = System.currentTimeMillis()
+ }
+
+ def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ diskSize: Long) {
+
+ updateLastSeenMs()
+
+ if (_blocks.containsKey(blockId)) {
+ // The block exists on the slave already.
+ val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel
+
+ if (originalLevel.useMemory) {
+ _remainingMem += memSize
+ }
+ }
+
+ if (storageLevel.isValid) {
+ // isValid means it is either stored in-memory or on-disk.
+ _blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
+ if (storageLevel.useMemory) {
+ _remainingMem -= memSize
+ logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
+ Utils.bytesToString(_remainingMem)))
+ }
+ if (storageLevel.useDisk) {
+ logInfo("Added %s on disk on %s (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
+ }
+ } else if (_blocks.containsKey(blockId)) {
+ // If isValid is not true, drop the block.
+ val blockStatus: BlockStatus = _blocks.get(blockId)
+ _blocks.remove(blockId)
+ if (blockStatus.storageLevel.useMemory) {
+ _remainingMem += blockStatus.memSize
+ logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
+ Utils.bytesToString(_remainingMem)))
+ }
+ if (blockStatus.storageLevel.useDisk) {
+ logInfo("Removed %s on %s on disk (size: %s)".format(
+ blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
+ }
+ }
+ }
+
+ def removeBlock(blockId: String) {
+ if (_blocks.containsKey(blockId)) {
+ _remainingMem += _blocks.get(blockId).memSize
+ _blocks.remove(blockId)
+ }
+ }
+
+ def remainingMem: Long = _remainingMem
+
+ def lastSeenMs: Long = _lastSeenMs
+
+ def blocks: JHashMap[String, BlockStatus] = _blocks
+
+ override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
+
+ def clear() {
+ _blocks.clear()
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
new file mode 100644
index 0000000000..24333a179c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import akka.actor.ActorRef
+
+
+private[storage] object BlockManagerMessages {
+ //////////////////////////////////////////////////////////////////////////////////
+ // Messages from the master to slaves.
+ //////////////////////////////////////////////////////////////////////////////////
+ sealed trait ToBlockManagerSlave
+
+ // Remove a block from the slaves that have it. This can only be used to remove
+ // blocks that the master knows about.
+ case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+
+ // Remove all blocks belonging to a specific RDD.
+ case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+
+
+ //////////////////////////////////////////////////////////////////////////////////
+ // Messages from slaves to the master.
+ //////////////////////////////////////////////////////////////////////////////////
+ sealed trait ToBlockManagerMaster
+
+ case class RegisterBlockManager(
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ sender: ActorRef)
+ extends ToBlockManagerMaster
+
+ case class HeartBeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+
+ class UpdateBlockInfo(
+ var blockManagerId: BlockManagerId,
+ var blockId: String,
+ var storageLevel: StorageLevel,
+ var memSize: Long,
+ var diskSize: Long)
+ extends ToBlockManagerMaster
+ with Externalizable {
+
+ def this() = this(null, null, null, 0, 0) // For deserialization only
+
+ override def writeExternal(out: ObjectOutput) {
+ blockManagerId.writeExternal(out)
+ out.writeUTF(blockId)
+ storageLevel.writeExternal(out)
+ out.writeLong(memSize)
+ out.writeLong(diskSize)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ blockManagerId = BlockManagerId(in)
+ blockId = in.readUTF()
+ storageLevel = StorageLevel(in)
+ memSize = in.readLong()
+ diskSize = in.readLong()
+ }
+ }
+
+ object UpdateBlockInfo {
+ def apply(blockManagerId: BlockManagerId,
+ blockId: String,
+ storageLevel: StorageLevel,
+ memSize: Long,
+ diskSize: Long): UpdateBlockInfo = {
+ new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)
+ }
+
+ // For pattern-matching
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
+ }
+ }
+
+ case class GetLocations(blockId: String) extends ToBlockManagerMaster
+
+ case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+
+ case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
+
+ case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
+
+ case object StopBlockManagerMaster extends ToBlockManagerMaster
+
+ case object GetMemoryStatus extends ToBlockManagerMaster
+
+ case object ExpireDeadHosts extends ToBlockManagerMaster
+
+ case object GetStorageStatus extends ToBlockManagerMaster
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
new file mode 100644
index 0000000000..951503019f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import akka.actor.Actor
+
+import org.apache.spark.storage.BlockManagerMessages._
+
+
+/**
+ * An actor that takes commands from the master and executes them locally. For example,
+ * this is used to remove blocks from the slave's BlockManager.
+ */
+class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
+ override def receive = {
+
+ case RemoveBlock(blockId) =>
+ blockManager.removeBlock(blockId)
+
+ case RemoveRdd(rddId) =>
+ val numBlocksRemoved = blockManager.removeRdd(rddId)
+ sender ! numBlocksRemoved
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
new file mode 100644
index 0000000000..24190cdd67
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala
@@ -0,0 +1,48 @@
+package org.apache.spark.storage
+
+import com.codahale.metrics.{Gauge, MetricRegistry}
+
+import org.apache.spark.metrics.source.Source
+
+
+private[spark] class BlockManagerSource(val blockManager: BlockManager) extends Source {
+ val metricRegistry = new MetricRegistry()
+ val sourceName = "BlockManager"
+
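+  // These gauges aggregate storage status across all block managers and report values in MB.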
+ metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] {
+ override def getValue: Long = {
+ val storageStatusList = blockManager.master.getStorageStatus
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _)
+ maxMem / 1024 / 1024
+ }
+ })
+
+ metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] {
+ override def getValue: Long = {
+ val storageStatusList = blockManager.master.getStorageStatus
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _)
+ remainingMem / 1024 / 1024
+ }
+ })
+
+ metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] {
+ override def getValue: Long = {
+ val storageStatusList = blockManager.master.getStorageStatus
+ val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _)
+ val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _)
+ (maxMem - remainingMem) / 1024 / 1024
+ }
+ })
+
+ metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] {
+ override def getValue: Long = {
+ val storageStatusList = blockManager.master.getStorageStatus
+ val diskSpaceUsed = storageStatusList
+ .flatMap(_.blocks.values.map(_.diskSize))
+ .reduceOption(_ + _)
+ .getOrElse(0L)
+
+ diskSpaceUsed / 1024 / 1024
+ }
+ })
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
new file mode 100644
index 0000000000..f4856020e5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.{Logging, Utils}
+import org.apache.spark.network._
+
+/**
+ * A network interface for BlockManager. Each slave should have one
+ * BlockManagerWorker.
+ *
+ * TODO: Use event model.
+ */
+private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
+ initLogging()
+
+ blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive)
+
+ def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = {
+ logDebug("Handling message " + msg)
+ msg match {
+ case bufferMessage: BufferMessage => {
+ try {
+ logDebug("Handling as a buffer message " + bufferMessage)
+ val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage)
+ logDebug("Parsed as a block message array")
+ val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
+ return Some(new BlockMessageArray(responseMessages).toBufferMessage)
+ } catch {
+ case e: Exception => logError("Exception handling buffer message", e)
+ return None
+ }
+ }
+ case otherMessage: Any => {
+ logError("Unknown type message received: " + otherMessage)
+ return None
+ }
+ }
+ }
+
+ def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = {
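+    // Puts need no reply; gets reply with the block contents, or nothing if the block is absent.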
+ blockMessage.getType match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+ logDebug("Received [" + pB + "]")
+ putBlock(pB.id, pB.data, pB.level)
+ return None
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId)
+ logDebug("Received [" + gB + "]")
+ val buffer = getBlock(gB.id)
+ if (buffer == null) {
+ return None
+ }
+ return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer)))
+ }
+ case _ => return None
+ }
+ }
+
+ private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
+ blockManager.putBytes(id, bytes, level)
+ logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " with data size: " + bytes.limit)
+ }
+
+ private def getBlock(id: String): ByteBuffer = {
+ val startTimeMs = System.currentTimeMillis()
+ logDebug("GetBlock " + id + " started from " + startTimeMs)
+ val buffer = blockManager.getLocalBytes(id) match {
+ case Some(bytes) => bytes
+ case None => null
+ }
+ logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs)
+ + " and got buffer " + buffer)
+ return buffer
+ }
+}
+
+private[spark] object BlockManagerWorker extends Logging {
+ private var blockManagerWorker: BlockManagerWorker = null
+
+ initLogging()
+
+ def startBlockManagerWorker(manager: BlockManager) {
+ blockManagerWorker = new BlockManagerWorker(manager)
+ }
+
+ def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val blockMessage = BlockMessage.fromPutBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val resultMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage)
+    return resultMessage.isDefined
+ }
+
+ def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = {
+ val blockManager = blockManagerWorker.blockManager
+ val connectionManager = blockManager.connectionManager
+ val blockMessage = BlockMessage.fromGetBlock(msg)
+ val blockMessageArray = new BlockMessageArray(blockMessage)
+ val responseMessage = connectionManager.sendMessageReliablySync(
+ toConnManagerId, blockMessageArray.toBufferMessage)
+ responseMessage match {
+ case Some(message) => {
+ val bufferMessage = message.asInstanceOf[BufferMessage]
+ logDebug("Response message received " + bufferMessage)
+ BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => {
+ logDebug("Found " + blockMessage)
+ return blockMessage.getData
+ })
+ }
+ case None => logDebug("No response message received"); return null
+ }
+ return null
+ }
+}
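A sketch of how a caller might use the synchronous helpers above to replicate a block to a peer and read it back. The object name, block ID, and peer argument are illustrative; the sketch sits in the storage package because the helpers are private[spark], and it assumes startBlockManagerWorker has already been called on this node:

    package org.apache.spark.storage

    import java.nio.ByteBuffer

    import org.apache.spark.network.ConnectionManagerId

    object RemoteBlockSketch {
      /** Push a small block to a remote worker and fetch it back; returns true on success. */
      def roundTrip(peer: ConnectionManagerId): Boolean = {
        val bytes = ByteBuffer.wrap("hello".getBytes("UTF-8"))
        val putOk = BlockManagerWorker.syncPutBlock(
          PutBlock("example_block_0", bytes, StorageLevel.MEMORY_ONLY_SER), peer)
        val fetched = BlockManagerWorker.syncGetBlock(GetBlock("example_block_0"), peer)
        putOk && (fetched != null)
      }
    }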
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
new file mode 100644
index 0000000000..d8fa6a91d1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.StringBuilder
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.network._
+
+private[spark] case class GetBlock(id: String)
+private[spark] case class GotBlock(id: String, data: ByteBuffer)
+private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+
+private[spark] class BlockMessage() {
+ // Un-initialized: typ = 0
+ // GetBlock: typ = 1
+ // GotBlock: typ = 2
+ // PutBlock: typ = 3
+ private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
+ private var id: String = null
+ private var data: ByteBuffer = null
+ private var level: StorageLevel = null
+
+ def set(getBlock: GetBlock) {
+ typ = BlockMessage.TYPE_GET_BLOCK
+ id = getBlock.id
+ }
+
+ def set(gotBlock: GotBlock) {
+ typ = BlockMessage.TYPE_GOT_BLOCK
+ id = gotBlock.id
+ data = gotBlock.data
+ }
+
+ def set(putBlock: PutBlock) {
+ typ = BlockMessage.TYPE_PUT_BLOCK
+ id = putBlock.id
+ data = putBlock.data
+ level = putBlock.level
+ }
+
+ def set(buffer: ByteBuffer) {
+ val startTime = System.currentTimeMillis
+ /*
+ println()
+ println("BlockMessage: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ typ = buffer.getInt()
+ val idLength = buffer.getInt()
+ val idBuilder = new StringBuilder(idLength)
+ for (i <- 1 to idLength) {
+ idBuilder += buffer.getChar()
+ }
+ id = idBuilder.toString()
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+
+ val booleanInt = buffer.getInt()
+ val replication = buffer.getInt()
+ level = StorageLevel(booleanInt, replication)
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+
+ val dataLength = buffer.getInt()
+ data = ByteBuffer.allocate(dataLength)
+ if (dataLength != buffer.remaining) {
+ throw new Exception("Error parsing buffer")
+ }
+ data.put(buffer)
+ data.flip()
+ }
+
+ val finishTime = System.currentTimeMillis
+ }
+
+ def set(bufferMsg: BufferMessage) {
+ val buffer = bufferMsg.buffers.apply(0)
+ buffer.clear()
+ set(buffer)
+ }
+
+ def getType: Int = {
+ return typ
+ }
+
+ def getId: String = {
+ return id
+ }
+
+ def getData: ByteBuffer = {
+ return data
+ }
+
+ def getLevel: StorageLevel = {
+ return level
+ }
+
+ def toBufferMessage: BufferMessage = {
+ val startTime = System.currentTimeMillis
+ val buffers = new ArrayBuffer[ByteBuffer]()
+ var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
+ buffer.putInt(typ).putInt(id.length())
+ id.foreach((x: Char) => buffer.putChar(x))
+ buffer.flip()
+ buffers += buffer
+
+ if (typ == BlockMessage.TYPE_PUT_BLOCK) {
+ buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication)
+ buffer.flip()
+ buffers += buffer
+
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ } else if (typ == BlockMessage.TYPE_GOT_BLOCK) {
+ buffer = ByteBuffer.allocate(4).putInt(data.remaining)
+ buffer.flip()
+ buffers += buffer
+
+ buffers += data
+ }
+
+ /*
+ println()
+ println("BlockMessage: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ val finishTime = System.currentTimeMillis
+ return Message.createBufferMessage(buffers)
+ }
+
+ override def toString: String = {
+ "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level +
+ ", data = " + (if (data != null) data.remaining.toString else "null") + "]"
+ }
+}
+
+private[spark] object BlockMessage {
+ val TYPE_NON_INITIALIZED: Int = 0
+ val TYPE_GET_BLOCK: Int = 1
+ val TYPE_GOT_BLOCK: Int = 2
+ val TYPE_PUT_BLOCK: Int = 3
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(bufferMessage)
+ newBlockMessage
+ }
+
+ def fromByteBuffer(buffer: ByteBuffer): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(buffer)
+ newBlockMessage
+ }
+
+ def fromGetBlock(getBlock: GetBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(getBlock)
+ newBlockMessage
+ }
+
+ def fromGotBlock(gotBlock: GotBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(gotBlock)
+ newBlockMessage
+ }
+
+ def fromPutBlock(putBlock: PutBlock): BlockMessage = {
+ val newBlockMessage = new BlockMessage()
+ newBlockMessage.set(putBlock)
+ newBlockMessage
+ }
+
+ def main(args: Array[String]) {
+ val B = new BlockMessage()
+ B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
+ val bMsg = B.toBufferMessage
+ val C = new BlockMessage()
+ C.set(bMsg)
+
+ println(B.getId + " " + B.getLevel)
+ println(C.getId + " " + C.getLevel)
+ }
+}
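For reference, a small sketch that adds up the byte layout toBufferMessage produces for a PutBlock message; the arithmetic follows directly from the buffer allocations above, and the 36-byte total matches the PutBlock built in main() ("ABC" id, 10-byte data buffer):

    object BlockMessageSizeSketch {
      /** Total bytes of the buffers toBufferMessage emits for a PutBlock. */
      def putBlockWireSize(id: String, dataBytes: Int): Int = {
        val header = 4 + 4 + id.length * 2  // int typ, int id length, id chars (2 bytes each)
        val level  = 4 + 4                  // int StorageLevel flags, int replication
        val length = 4                      // int data length
        header + level + length + dataBytes
      }

      def main(args: Array[String]) {
        println(putBlockWireSize("ABC", 10))  // 36
      }
    }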
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
new file mode 100644
index 0000000000..0aaf846b5b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.network._
+
+private[spark]
+class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging {
+
+ def this(bm: BlockMessage) = this(Array(bm))
+
+ def this() = this(null.asInstanceOf[Seq[BlockMessage]])
+
+ def apply(i: Int) = blockMessages(i)
+
+ def iterator = blockMessages.iterator
+
+ def length = blockMessages.length
+
+ initLogging()
+
+ def set(bufferMessage: BufferMessage) {
+ val startTime = System.currentTimeMillis
+ val newBlockMessages = new ArrayBuffer[BlockMessage]()
+ val buffer = bufferMessage.buffers(0)
+ buffer.clear()
+ /*
+ println()
+ println("BlockMessageArray: ")
+ while(buffer.remaining > 0) {
+ print(buffer.get())
+ }
+ buffer.rewind()
+ println()
+ println()
+ */
+ while (buffer.remaining() > 0) {
+ val size = buffer.getInt()
+ logDebug("Creating block message of size " + size + " bytes")
+ val newBuffer = buffer.slice()
+ newBuffer.clear()
+ newBuffer.limit(size)
+ logDebug("Trying to convert buffer " + newBuffer + " to block message")
+ val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer)
+ logDebug("Created " + newBlockMessage)
+ newBlockMessages += newBlockMessage
+ buffer.position(buffer.position() + size)
+ }
+ val finishTime = System.currentTimeMillis
+ logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s")
+ this.blockMessages = newBlockMessages
+ }
+
+ def toBufferMessage: BufferMessage = {
+ val buffers = new ArrayBuffer[ByteBuffer]()
+
+ blockMessages.foreach(blockMessage => {
+ val bufferMessage = blockMessage.toBufferMessage
+ logDebug("Adding " + blockMessage)
+ val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size)
+ sizeBuffer.flip
+ buffers += sizeBuffer
+ buffers ++= bufferMessage.buffers
+ logDebug("Added " + bufferMessage)
+ })
+
+ logDebug("Buffer list:")
+ buffers.foreach((x: ByteBuffer) => logDebug("" + x))
+ /*
+ println()
+ println("BlockMessageArray: ")
+ buffers.foreach(b => {
+ while(b.remaining > 0) {
+ print(b.get())
+ }
+ b.rewind()
+ })
+ println()
+ println()
+ */
+ return Message.createBufferMessage(buffers)
+ }
+}
+
+private[spark] object BlockMessageArray {
+
+ def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = {
+ val newBlockMessageArray = new BlockMessageArray()
+ newBlockMessageArray.set(bufferMessage)
+ newBlockMessageArray
+ }
+
+ def main(args: Array[String]) {
+ val blockMessages =
+ (0 until 10).map { i =>
+ if (i % 2 == 0) {
+ val buffer = ByteBuffer.allocate(100)
+ buffer.clear
+ BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
+ } else {
+ BlockMessage.fromGetBlock(GetBlock(i.toString))
+ }
+ }
+ val blockMessageArray = new BlockMessageArray(blockMessages)
+ println("Block message array created")
+
+ val bufferMessage = blockMessageArray.toBufferMessage
+ println("Converted to buffer message")
+
+ val totalSize = bufferMessage.size
+ val newBuffer = ByteBuffer.allocate(totalSize)
+ newBuffer.clear()
+ bufferMessage.buffers.foreach(buffer => {
+ assert (0 == buffer.position())
+ newBuffer.put(buffer)
+ buffer.rewind()
+ })
+ newBuffer.flip
+ val newBufferMessage = Message.createBufferMessage(newBuffer)
+ println("Copied to new buffer message, size = " + newBufferMessage.size)
+
+ val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage)
+ println("Converted back to block message array")
+ newBlockMessageArray.foreach(blockMessage => {
+ blockMessage.getType match {
+ case BlockMessage.TYPE_PUT_BLOCK => {
+ val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
+ println(pB)
+ }
+ case BlockMessage.TYPE_GET_BLOCK => {
+ val gB = new GetBlock(blockMessage.getId)
+ println(gB)
+ }
+ }
+ })
+ }
+}
+
+
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
new file mode 100644
index 0000000000..39f103297f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+
+/**
+ * An interface for writing JVM objects to some underlying storage. This interface allows
+ * appending data to an existing block, and can guarantee atomicity in the case of faults
+ * as it allows the caller to revert partial writes.
+ *
+ * This interface does not support concurrent writes.
+ */
+abstract class BlockObjectWriter(val blockId: String) {
+
+ var closeEventHandler: () => Unit = _
+
+ def open(): BlockObjectWriter
+
+ def close() {
+ closeEventHandler()
+ }
+
+ def isOpen: Boolean
+
+ def registerCloseEventHandler(handler: () => Unit) {
+ closeEventHandler = handler
+ }
+
+ /**
+ * Flush the partial writes and commit them as a single atomic block. Return the
+ * number of bytes written for this commit.
+ */
+ def commit(): Long
+
+ /**
+ * Reverts writes that haven't been flushed yet. Callers should invoke this function
+ * when there are runtime exceptions.
+ */
+ def revertPartialWrites()
+
+ /**
+ * Writes an object.
+ */
+ def write(value: Any)
+
+ /**
+ * Size of the valid writes, in bytes.
+ */
+ def size(): Long
+}
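A sketch of the calling pattern this interface is designed for: open, write, then either commit or revert. The writeBatch helper is hypothetical, and it assumes the caller has already registered a close event handler on the writer (in this patch, concrete writers come from DiskStore.getBlockWriter):

    import org.apache.spark.storage.BlockObjectWriter

    object WriterUsageSketch {
      /** Write a batch of records and commit them as one atomic block; returns bytes committed. */
      def writeBatch(writer: BlockObjectWriter, records: Iterator[Any]): Long = {
        writer.open()
        try {
          records.foreach(writer.write)
          writer.commit()                 // flushes and returns the size of this commit
        } catch {
          case e: Exception =>
            writer.revertPartialWrites()  // truncates back to the last committed position
            throw e
        } finally {
          writer.close()
        }
      }
    }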
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
new file mode 100644
index 0000000000..fa834371f4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.Logging
+
+/**
+ * Abstract class to store blocks
+ */
+private[spark]
+abstract class BlockStore(val blockManager: BlockManager) extends Logging {
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+
+ /**
+ * Put in a block and, possibly, also return its content as either bytes or another Iterator.
+ * This is used to efficiently write the values to multiple locations (e.g. for replication).
+ *
+ * @return a PutResult that contains the size of the data, as well as the values put if
+ * returnValues is true (if not, the result's data field can be null)
+ */
+ def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ returnValues: Boolean) : PutResult
+
+ /**
+ * Return the size of a block in bytes.
+ */
+ def getSize(blockId: String): Long
+
+ def getBytes(blockId: String): Option[ByteBuffer]
+
+ def getValues(blockId: String): Option[Iterator[Any]]
+
+ /**
+ * Remove a block, if it exists.
+ * @param blockId the block to remove.
+ * @return True if the block was found and removed, False otherwise.
+ */
+ def remove(blockId: String): Boolean
+
+ def contains(blockId: String): Boolean
+
+ def clear() { }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
new file mode 100644
index 0000000000..fd945e065c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.nio.channels.FileChannel.MapMode
+import java.util.{Random, Date}
+import java.text.SimpleDateFormat
+
+import scala.collection.mutable.ArrayBuffer
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Utils
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.serializer.{Serializer, SerializationStream}
+import org.apache.spark.Logging
+import org.apache.spark.network.netty.ShuffleSender
+import org.apache.spark.network.netty.PathResolver
+
+
+/**
+ * Stores BlockManager blocks on disk.
+ */
+private class DiskStore(blockManager: BlockManager, rootDirs: String)
+ extends BlockStore(blockManager) with Logging {
+
+ class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ extends BlockObjectWriter(blockId) {
+
+ private val f: File = createFile(blockId /*, allowAppendExisting */)
+
+ // The file channel, used for repositioning / truncating the file.
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var objOut: SerializationStream = null
+ private var lastValidPosition = 0L
+ private var initialized = false
+
+ override def open(): DiskBlockObjectWriter = {
+ val fos = new FileOutputStream(f, true)
+ channel = fos.getChannel()
+ bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ objOut.close()
+ channel = null
+ bs = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ // Flush the partial writes, and set valid length to be the length of the entire file.
+ // Return the number of bytes written for this commit.
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def size(): Long = lastValidPosition
+ }
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ private var shuffleSender : ShuffleSender = null
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+  // having very large directories at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+
+ addShutdownHook()
+
+ def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ : BlockObjectWriter = {
+ new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ }
+
+ override def getSize(blockId: String): Long = {
+ getFile(blockId).length()
+ }
+
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+    // Duplicate the buffer so that we do not modify the input's position and limit.
+    // duplicate() does not copy the underlying data, so this is inexpensive.
+ val bytes = _bytes.duplicate()
+ logDebug("Attempting to put block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ val channel = new RandomAccessFile(file, "rw").getChannel()
+ while (bytes.remaining > 0) {
+ channel.write(bytes)
+ }
+ channel.close()
+ val finishTime = System.currentTimeMillis
+ logDebug("Block %s stored as %s file on disk in %d ms".format(
+ blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
+ }
+
+ private def getFileBytes(file: File): ByteBuffer = {
+ val length = file.length()
+ val channel = new RandomAccessFile(file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, 0, length)
+ } finally {
+ channel.close()
+ }
+
+ buffer
+ }
+
+ override def putValues(
+ blockId: String,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ logDebug("Attempting to write values for block " + blockId)
+ val startTime = System.currentTimeMillis
+ val file = createFile(blockId)
+ val fileOut = blockManager.wrapForCompression(blockId,
+ new FastBufferedOutputStream(new FileOutputStream(file)))
+ val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
+ objOut.writeAll(values.iterator)
+ objOut.close()
+ val length = file.length()
+
+ val timeTaken = System.currentTimeMillis - startTime
+ logDebug("Block %s stored as %s file on disk in %d ms".format(
+ blockId, Utils.bytesToString(length), timeTaken))
+
+ if (returnValues) {
+ // Return a byte buffer for the contents of the file
+ val buffer = getFileBytes(file)
+ PutResult(length, Right(buffer))
+ } else {
+ PutResult(length, null)
+ }
+ }
+
+ override def getBytes(blockId: String): Option[ByteBuffer] = {
+ val file = getFile(blockId)
+ val bytes = getFileBytes(file)
+ Some(bytes)
+ }
+
+ override def getValues(blockId: String): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ }
+
+ /**
+ * A version of getValues that allows a custom serializer. This is used as part of the
+ * shuffle short-circuit code.
+ */
+ def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
+ }
+
+ override def remove(blockId: String): Boolean = {
+ val file = getFile(blockId)
+ if (file.exists()) {
+ file.delete()
+ } else {
+ false
+ }
+ }
+
+ override def contains(blockId: String): Boolean = {
+ getFile(blockId).exists()
+ }
+
+ private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
+ val file = getFile(blockId)
+ if (!allowAppendExisting && file.exists()) {
+ // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
+ // was rescheduled on the same machine as the old task.
+ logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
+ file.delete()
+ }
+ file
+ }
+
+ private def getFile(blockId: String): File = {
+ logDebug("Getting file for block " + blockId)
+
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = math.abs(blockId.hashCode)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, blockId)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+ if (shuffleSender != null) {
+ shuffleSender.stop
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ val pResolver = new PathResolver {
+ override def getAbsolutePath(blockId: String): String = {
+ if (!blockId.startsWith("shuffle_")) {
+ return null
+ }
+ DiskStore.this.getFile(blockId).getAbsolutePath()
+ }
+ }
+ shuffleSender = new ShuffleSender(port, pResolver)
+ logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
+ shuffleSender.port
+ }
+}
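A standalone sketch of the directory-hashing scheme getFile uses; the directory counts are invented (8 local dirs, the default 64 subdirectories) just to show which bucket a given block ID lands in:

    object DiskHashingSketch {
      def main(args: Array[String]) {
        val numLocalDirs = 8         // one per path in spark.local.dir (illustrative)
        val subDirsPerLocalDir = 64  // spark.diskStore.subDirectories default

        val blockId = "shuffle_0_5_3"
        val hash = math.abs(blockId.hashCode)
        val dirId = hash % numLocalDirs
        val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir

        // The block ends up at <localDirs(dirId)>/<subDirId as two hex digits>/<blockId>.
        println("dirId=" + dirId + ", subDir=" + "%02x".format(subDirId))
      }
    }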
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
new file mode 100644
index 0000000000..828dc0f22d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.LinkedHashMap
+import java.util.concurrent.ArrayBlockingQueue
+import org.apache.spark.{SizeEstimator, Utils}
+import java.nio.ByteBuffer
+import collection.mutable.ArrayBuffer
+
+/**
+ * Stores blocks in memory, either as ArrayBuffers of deserialized Java objects or as
+ * serialized ByteBuffers.
+ */
+private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
+ extends BlockStore(blockManager) {
+
+ case class Entry(value: Any, size: Long, deserialized: Boolean, var dropPending: Boolean = false)
+
+ private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private var currentMemory = 0L
+ // Object used to ensure that only one thread is putting blocks and if necessary, dropping
+ // blocks from the memory store.
+ private val putLock = new Object()
+
+ logInfo("MemoryStore started with capacity %s.".format(Utils.bytesToString(maxMemory)))
+
+ def freeMemory: Long = maxMemory - currentMemory
+
+ override def getSize(blockId: String): Long = {
+ entries.synchronized {
+ entries.get(blockId).size
+ }
+ }
+
+ override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ // Work on a duplicate - since the original input might be used elsewhere.
+ val bytes = _bytes.duplicate()
+ bytes.rewind()
+ if (level.deserialized) {
+ val values = blockManager.dataDeserialize(blockId, bytes)
+ val elements = new ArrayBuffer[Any]
+ elements ++= values
+ val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef])
+ tryToPut(blockId, elements, sizeEstimate, true)
+ } else {
+ tryToPut(blockId, bytes, bytes.limit, false)
+ }
+ }
+
+ override def putValues(
+ blockId: String,
+ values: ArrayBuffer[Any],
+ level: StorageLevel,
+ returnValues: Boolean)
+ : PutResult = {
+
+ if (level.deserialized) {
+ val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
+ tryToPut(blockId, values, sizeEstimate, true)
+ PutResult(sizeEstimate, Left(values.iterator))
+ } else {
+ val bytes = blockManager.dataSerialize(blockId, values.iterator)
+ tryToPut(blockId, bytes, bytes.limit, false)
+ PutResult(bytes.limit(), Right(bytes.duplicate()))
+ }
+ }
+
+ override def getBytes(blockId: String): Option[ByteBuffer] = {
+ val entry = entries.synchronized {
+ entries.get(blockId)
+ }
+ if (entry == null) {
+ None
+ } else if (entry.deserialized) {
+ Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[ArrayBuffer[Any]].iterator))
+ } else {
+ Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
+ }
+ }
+
+ override def getValues(blockId: String): Option[Iterator[Any]] = {
+ val entry = entries.synchronized {
+ entries.get(blockId)
+ }
+ if (entry == null) {
+ None
+ } else if (entry.deserialized) {
+ Some(entry.value.asInstanceOf[ArrayBuffer[Any]].iterator)
+ } else {
+ val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data
+ Some(blockManager.dataDeserialize(blockId, buffer))
+ }
+ }
+
+ override def remove(blockId: String): Boolean = {
+ entries.synchronized {
+ val entry = entries.get(blockId)
+ if (entry != null) {
+ entries.remove(blockId)
+ currentMemory -= entry.size
+ logInfo("Block %s of size %d dropped from memory (free %d)".format(
+ blockId, entry.size, freeMemory))
+ true
+ } else {
+ false
+ }
+ }
+ }
+
+ override def clear() {
+ entries.synchronized {
+ entries.clear()
+ }
+ logInfo("MemoryStore cleared")
+ }
+
+ /**
+ * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ */
+ private def getRddId(blockId: String): String = {
+ if (blockId.startsWith("rdd_")) {
+ blockId.split('_')(1)
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Try to put in a set of values, if we can free up enough space. The value should either be
+ * an ArrayBuffer if deserialized is true or a ByteBuffer otherwise. Its (possibly estimated)
+ * size must also be passed by the caller.
+ *
+   * Locks on the object putLock to ensure that all the put requests and their associated block
+   * dropping are done by only one thread at a time. Otherwise, while one thread is dropping
+ * blocks to free memory for one block, another thread may use up the freed space for
+ * another block.
+ */
+ private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+    // TODO: It's possible to optimize the locking by locking entries only when selecting blocks
+    // to be dropped. Once the to-be-dropped blocks have been selected and the lock on entries has
+    // been released, it must be ensured that those to-be-dropped blocks are not double counted
+    // when freeing up space for another block that needs to be put. Only then can the actual
+    // dropping of blocks (and writing to disk if necessary) proceed in parallel.
+ putLock.synchronized {
+ if (ensureFreeSpace(blockId, size)) {
+ val entry = new Entry(value, size, deserialized)
+ entries.synchronized { entries.put(blockId, entry) }
+ currentMemory += size
+ if (deserialized) {
+ logInfo("Block %s stored as values to memory (estimated size %s, free %s)".format(
+ blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory)))
+ } else {
+ logInfo("Block %s stored as bytes to memory (size %s, free %s)".format(
+ blockId, Utils.bytesToString(size), Utils.bytesToString(freeMemory)))
+ }
+ true
+ } else {
+ // Tell the block manager that we couldn't put it in memory so that it can drop it to
+ // disk if the block allows disk storage.
+ val data = if (deserialized) {
+ Left(value.asInstanceOf[ArrayBuffer[Any]])
+ } else {
+ Right(value.asInstanceOf[ByteBuffer].duplicate())
+ }
+ blockManager.dropFromMemory(blockId, data)
+ false
+ }
+ }
+ }
+
+ /**
+ * Tries to free up a given amount of space to store a particular block, but can fail and return
+ * false if either the block is bigger than our memory or it would require replacing another
+ * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
+ * don't fit into memory that we want to avoid).
+ *
+ * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
+ * Otherwise, the freed space may fill up before the caller puts in their new value.
+ */
+ private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+
+ logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
+ space, currentMemory, maxMemory))
+
+ if (space > maxMemory) {
+ logInfo("Will not store " + blockIdToAdd + " as it is larger than our memory limit")
+ return false
+ }
+
+ if (maxMemory - currentMemory < space) {
+ val rddToAdd = getRddId(blockIdToAdd)
+ val selectedBlocks = new ArrayBuffer[String]()
+ var selectedMemory = 0L
+
+ // This is synchronized to ensure that the set of entries is not changed
+ // (because of getValue or getBytes) while traversing the iterator, as that
+ // can lead to exceptions.
+ entries.synchronized {
+ val iterator = entries.entrySet().iterator()
+ while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
+ val pair = iterator.next()
+ val blockId = pair.getKey
+ if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+ logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
+ "block from the same RDD")
+ return false
+ }
+ selectedBlocks += blockId
+ selectedMemory += pair.getValue.size
+ }
+ }
+
+ if (maxMemory - (currentMemory - selectedMemory) >= space) {
+ logInfo(selectedBlocks.size + " blocks selected for dropping")
+ for (blockId <- selectedBlocks) {
+ val entry = entries.synchronized { entries.get(blockId) }
+ // This should never be null as only one thread should be dropping
+ // blocks and removing entries. However the check is still here for
+ // future safety.
+ if (entry != null) {
+ val data = if (entry.deserialized) {
+ Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
+ } else {
+ Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+ }
+ blockManager.dropFromMemory(blockId, data)
+ }
+ }
+ return true
+ } else {
+ return false
+ }
+ }
+ return true
+ }
+
+ override def contains(blockId: String): Boolean = {
+ entries.synchronized { entries.containsKey(blockId) }
+ }
+}
+
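The eviction order ensureFreeSpace relies on comes from constructing the LinkedHashMap with accessOrder = true; here is a small standalone sketch of that behaviour, with invented block names and sizes:

    import java.util.LinkedHashMap

    object LruOrderSketch {
      def main(args: Array[String]) {
        // Same constructor arguments as MemoryStore's entries map: access-ordered.
        val entries = new LinkedHashMap[String, Long](32, 0.75f, true)
        entries.put("rdd_1_0", 100L)
        entries.put("rdd_1_1", 100L)
        entries.put("rdd_2_0", 100L)

        entries.get("rdd_1_0")  // touching a block moves it to the most-recently-used end

        // Iteration now runs least-recently-used first, which is the order
        // ensureFreeSpace walks when selecting blocks to drop.
        val it = entries.keySet().iterator()
        while (it.hasNext) println(it.next())  // rdd_1_1, rdd_2_0, rdd_1_0
      }
    }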
diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
new file mode 100644
index 0000000000..2eba2f06b5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.nio.ByteBuffer
+
+/**
+ * Result of adding a block into a BlockStore. Contains its estimated size, and possibly the
+ * values put if the caller asked for them to be returned (e.g. for chaining replication)
+ */
+private[spark] case class PutResult(size: Long, data: Either[Iterator[_], ByteBuffer])
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
new file mode 100644
index 0000000000..9da11efb57
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.serializer.Serializer
+
+
+private[spark]
+class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+
+
+private[spark]
+trait ShuffleBlocks {
+ def acquireWriters(mapId: Int): ShuffleWriterGroup
+ def releaseWriters(group: ShuffleWriterGroup)
+}
+
+
+private[spark]
+class ShuffleBlockManager(blockManager: BlockManager) {
+
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ new ShuffleBlocks {
+ // Get a group of writers for a map task.
+ override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
+ val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
+ blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ }
+ new ShuffleWriterGroup(mapId, writers)
+ }
+
+ override def releaseWriters(group: ShuffleWriterGroup) = {
+ // Nothing really to release here.
+ }
+ }
+ }
+}
+
+
+private[spark]
+object ShuffleBlockManager {
+
+ // Returns the block id for a given shuffle block.
+ def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
+ "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+ }
+
+ // Returns true if the block is a shuffle block.
+ def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+}
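A quick illustration of the naming convention blockId implements; the sketch sits in the storage package since the object is private[spark], and the IDs are arbitrary:

    package org.apache.spark.storage

    object ShuffleBlockIdSketch {
      def main(args: Array[String]) {
        // shuffle 2, map task (groupId) 7, reduce bucket 3
        val id = ShuffleBlockManager.blockId(shuffleId = 2, bucketId = 3, groupId = 7)
        println(id)                                 // shuffle_2_7_3
        println(ShuffleBlockManager.isShuffle(id))  // true
      }
    }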
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
new file mode 100644
index 0000000000..755f1a760e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+
+/**
+ * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory,
+ * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory
+ * in a serialized format, and whether to replicate the RDD partitions on multiple nodes.
+ * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants for
+ * commonly useful storage levels. To create your own storage level object, use the factory method
+ * of the singleton object (`StorageLevel(...)`).
+ */
+class StorageLevel private(
+ private var useDisk_ : Boolean,
+ private var useMemory_ : Boolean,
+ private var deserialized_ : Boolean,
+ private var replication_ : Int = 1)
+ extends Externalizable {
+
+ // TODO: Also add fields for caching priority, dataset ID, and flushing.
+ private def this(flags: Int, replication: Int) {
+ this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication)
+ }
+
+ def this() = this(false, true, false) // For deserialization
+
+ def useDisk = useDisk_
+ def useMemory = useMemory_
+ def deserialized = deserialized_
+ def replication = replication_
+
+ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes")
+
+ override def clone(): StorageLevel = new StorageLevel(
+ this.useDisk, this.useMemory, this.deserialized, this.replication)
+
+ override def equals(other: Any): Boolean = other match {
+ case s: StorageLevel =>
+ s.useDisk == useDisk &&
+ s.useMemory == useMemory &&
+ s.deserialized == deserialized &&
+ s.replication == replication
+ case _ =>
+ false
+ }
+
+ def isValid = ((useMemory || useDisk) && (replication > 0))
+
+ def toInt: Int = {
+ var ret = 0
+ if (useDisk_) {
+ ret |= 4
+ }
+ if (useMemory_) {
+ ret |= 2
+ }
+ if (deserialized_) {
+ ret |= 1
+ }
+ return ret
+ }
+
+ override def writeExternal(out: ObjectOutput) {
+ out.writeByte(toInt)
+ out.writeByte(replication_)
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val flags = in.readByte()
+ useDisk_ = (flags & 4) != 0
+ useMemory_ = (flags & 2) != 0
+ deserialized_ = (flags & 1) != 0
+ replication_ = in.readByte()
+ }
+
+ @throws(classOf[IOException])
+ private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this)
+
+ override def toString: String =
+ "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication)
+
+ override def hashCode(): Int = toInt * 41 + replication
+ def description : String = {
+ var result = ""
+ result += (if (useDisk) "Disk " else "")
+ result += (if (useMemory) "Memory " else "")
+    result += (if (deserialized) "Deserialized " else "Serialized ")
+ result += "%sx Replicated".format(replication)
+ result
+ }
+}
+
+
+object StorageLevel {
+ val NONE = new StorageLevel(false, false, false)
+ val DISK_ONLY = new StorageLevel(true, false, false)
+ val DISK_ONLY_2 = new StorageLevel(true, false, false, 2)
+ val MEMORY_ONLY = new StorageLevel(false, true, true)
+ val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2)
+ val MEMORY_ONLY_SER = new StorageLevel(false, true, false)
+ val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2)
+ val MEMORY_AND_DISK = new StorageLevel(true, true, true)
+ val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2)
+ val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false)
+ val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2)
+
+ /** Create a new StorageLevel object */
+ def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, replication: Int = 1) =
+ getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication))
+
+ /** Create a new StorageLevel object from its integer representation */
+ def apply(flags: Int, replication: Int) =
+ getCachedStorageLevel(new StorageLevel(flags, replication))
+
+ /** Read StorageLevel object from ObjectInput stream */
+ def apply(in: ObjectInput) = {
+ val obj = new StorageLevel()
+ obj.readExternal(in)
+ getCachedStorageLevel(obj)
+ }
+
+ private[spark]
+ val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+
+ private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
+ storageLevelCache.putIfAbsent(level, level)
+ storageLevelCache.get(level)
+ }
+}
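A worked example of the flag encoding behind toInt and the (flags, replication) constructor; the values follow directly from the bit masks above (4 = useDisk, 2 = useMemory, 1 = deserialized):

    import org.apache.spark.storage.StorageLevel

    object StorageLevelFlagsSketch {
      def main(args: Array[String]) {
        val level = StorageLevel.MEMORY_AND_DISK_SER_2  // disk + memory, serialized, 2 replicas
        println(level.toInt)        // 6, i.e. 4 | 2
        println(level.replication)  // 2

        // Round trip through the integer form, mirroring writeExternal/readExternal.
        val rebuilt = StorageLevel(level.toInt, level.replication)
        println(rebuilt == level)   // true
      }
    }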
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
new file mode 100644
index 0000000000..0bba1dac54
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.apache.spark.{Utils, SparkContext}
+import BlockManagerMasterActor.BlockStatus
+
+private[spark]
+case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
+ blocks: Map[String, BlockStatus]) {
+
+ def memUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
+      reduceOption(_ + _).getOrElse(0L)
+ }
+
+ def diskUsed(blockPrefix: String = "") = {
+ blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
+      reduceOption(_ + _).getOrElse(0L)
+ }
+
+ def memRemaining : Long = maxMem - memUsed()
+
+}
+
+case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
+ numCachedPartitions: Int, numPartitions: Int, memSize: Long, diskSize: Long)
+ extends Ordered[RDDInfo] {
+ override def toString = {
+ import Utils.bytesToString
+ "RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; DiskSize: %s".format(name, id,
+ storageLevel.toString, numCachedPartitions, numPartitions, bytesToString(memSize), bytesToString(diskSize))
+ }
+
+ override def compare(that: RDDInfo) = {
+ this.id - that.id
+ }
+}
+
+/* Helper methods for storage-related objects */
+private[spark]
+object StorageUtils {
+
+ /* Returns RDD-level information, compiled from a list of StorageStatus objects */
+ def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ }
+
+ /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
+ def blockLocationsFromStorageStatus(storageStatusList: Seq[StorageStatus]) = {
+ val blockLocationPairs = storageStatusList
+ .flatMap(s => s.blocks.map(b => (b._1, s.blockManagerId.hostPort)))
+ blockLocationPairs.groupBy(_._1).map{case (k, v) => (k, v.unzip._2)}.toMap
+ }
+
+  /* Given a list of BlockStatus objects, returns information for each RDD */
+ def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ sc: SparkContext) : Array[RDDInfo] = {
+
+ // Group by rddId, ignore the partition name
+ val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
+ k.substring(0,k.lastIndexOf('_'))
+ }.mapValues(_.values.toArray)
+
+ // For each RDD, generate an RDDInfo object
+ val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+ // Add up memory and disk sizes
+ val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
+ val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
+
+ // Find the id of the RDD, e.g. rdd_1 => 1
+ val rddId = rddKey.split("_").last.toInt
+
+ // Get the friendly name and storage level for the RDD, if available
+ sc.persistentRdds.get(rddId).map { r =>
+ val rddName = Option(r.name).getOrElse(rddKey)
+ val rddStorageLevel = r.getStorageLevel
+ RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
+ }
+ }.flatten.toArray
+
+ scala.util.Sorting.quickSort(rddInfos)
+
+ rddInfos
+ }
+
+  /* Removes all BlockStatus objects whose IDs do not start with the given block prefix */
+ def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
+ prefix: String) : Array[StorageStatus] = {
+
+ storageStatusList.map { status =>
+ val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
+ StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
+ }
+
+ }
+
+}
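The grouping step in rddInfoFromBlockStatusList just strips the partition suffix off each rdd block name; a tiny standalone sketch with invented block names and sizes:

    object RddGroupingSketch {
      def main(args: Array[String]) {
        val blockSizes =
          Map("rdd_1_0" -> 10L, "rdd_1_1" -> 20L, "rdd_2_0" -> 5L, "broadcast_0" -> 99L)

        // Group by rddId, ignoring the partition index (same substring trick as above).
        val perRdd = blockSizes.filterKeys(_.startsWith("rdd_")).groupBy { case (k, _) =>
          k.substring(0, k.lastIndexOf('_'))
        }.mapValues(_.values.sum)

        println(perRdd)  // Map(rdd_1 -> 30, rdd_2 -> 5)
      }
    }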
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
new file mode 100644
index 0000000000..1d5afe9b08
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import akka.actor._
+
+import org.apache.spark.KryoSerializer
+import java.util.concurrent.ArrayBlockingQueue
+import util.Random
+
+/**
+ * This class tests the BlockManager and MemoryStore for thread safety and
+ * deadlocks. It spawns a number of producer and consumer threads. Producer
+ * threads continuously push blocks into the BlockManager and consumer
+ * threads continuously retrieve the blocks from the BlockManager and test
+ * whether each block is correct.
+ */
+private[spark] object ThreadingTest {
+
+ val numProducers = 5
+ val numBlocksPerProducer = 20000
+
+ private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
+ val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+
+ override def run() {
+ for (i <- 1 to numBlocksPerProducer) {
+ val blockId = "b-" + id + "-" + i
+ val blockSize = Random.nextInt(1000)
+ val block = (1 to blockSize).map(_ => Random.nextInt())
+ val level = randomLevel()
+ val startTime = System.currentTimeMillis()
+ manager.put(blockId, block.iterator, level, true)
+ println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
+ queue.add((blockId, block))
+ }
+ println("Producer thread " + id + " terminated")
+ }
+
+ def randomLevel(): StorageLevel = {
+ math.abs(Random.nextInt()) % 4 match {
+ case 0 => StorageLevel.MEMORY_ONLY
+ case 1 => StorageLevel.MEMORY_ONLY_SER
+ case 2 => StorageLevel.MEMORY_AND_DISK
+ case 3 => StorageLevel.MEMORY_AND_DISK_SER
+ }
+ }
+ }
+
+ private[spark] class ConsumerThread(
+ manager: BlockManager,
+ queue: ArrayBlockingQueue[(String, Seq[Int])]
+ ) extends Thread {
+ var numBlockConsumed = 0
+
+ override def run() {
+ println("Consumer thread started")
+ while(numBlockConsumed < numBlocksPerProducer) {
+ val (blockId, block) = queue.take()
+ val startTime = System.currentTimeMillis()
+ manager.get(blockId) match {
+ case Some(retrievedBlock) =>
+ assert(retrievedBlock.toList.asInstanceOf[List[Int]] == block.toList,
+ "Block " + blockId + " did not match")
+ println("Got block " + blockId + " in " +
+ (System.currentTimeMillis - startTime) + " ms")
+ case None =>
+ assert(false, "Block " + blockId + " could not be retrieved")
+ }
+ numBlockConsumed += 1
+ }
+ println("Consumer thread terminated")
+ }
+ }
+
+ def main(args: Array[String]) {
+ System.setProperty("spark.kryoserializer.buffer.mb", "1")
+ val actorSystem = ActorSystem("test")
+ val serializer = new KryoSerializer
+ val blockManagerMaster = new BlockManagerMaster(
+ actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ val blockManager = new BlockManager(
+ "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
+ val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
+ val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
+ producers.foreach(_.start)
+ consumers.foreach(_.start)
+ producers.foreach(_.join)
+ consumers.foreach(_.join)
+ blockManager.stop()
+ blockManagerMaster.stop()
+ actorSystem.shutdown()
+ actorSystem.awaitTermination()
+ println("Everything stopped.")
+ println(
+ "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.")
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
new file mode 100644
index 0000000000..cfa18f6ea4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.{HttpServletResponse, HttpServletRequest}
+
+import scala.annotation.tailrec
+import scala.util.{Try, Success, Failure}
+import scala.xml.Node
+
+import net.liftweb.json.{JValue, pretty, render}
+
+import org.eclipse.jetty.server.{Server, Request, Handler}
+import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler}
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+
+import org.apache.spark.Logging
+
+
+/** Utilities for launching a web server using Jetty's HTTP Server class */
+private[spark] object JettyUtils extends Logging {
+ // Base type for a function that returns something based on an HTTP request. Allows for
+ // implicit conversion from many types of functions to jetty Handlers.
+ type Responder[T] = HttpServletRequest => T
+
+ // Conversions from various types of Responders to jetty Handlers
+ implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler =
+ createHandler(responder, "text/json", (in: JValue) => pretty(render(in)))
+
+ implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler =
+ createHandler(responder, "text/html", (in: Seq[Node]) => "<!DOCTYPE html>" + in.toString)
+
+ implicit def textResponderToHandler(responder: Responder[String]): Handler =
+ createHandler(responder, "text/plain")
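+
+ // For example (illustrative sketch, mirroring how SparkUI and the per-tab UIs declare routes),
+ // a plain function is accepted wherever a Handler is expected:
+ //   val routes = Seq[(String, Handler)](("/health", (request: HttpServletRequest) => "ok"))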
+
+ def createHandler[T <% AnyRef](responder: Responder[T], contentType: String,
+ extractFn: T => String = (in: Any) => in.toString): Handler = {
+ new AbstractHandler {
+ def handle(target: String,
+ baseRequest: Request,
+ request: HttpServletRequest,
+ response: HttpServletResponse) {
+ response.setContentType("%s;charset=utf-8".format(contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ baseRequest.setHandled(true)
+ val result = responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter().println(extractFn(result))
+ }
+ }
+ }
+
+ /** Creates a handler that always redirects the user to a given path */
+ def createRedirectHandler(newPath: String): Handler = {
+ new AbstractHandler {
+ def handle(target: String,
+ baseRequest: Request,
+ request: HttpServletRequest,
+ response: HttpServletResponse) {
+ response.setStatus(302)
+ response.setHeader("Location", baseRequest.getRootURL + newPath)
+ baseRequest.setHandled(true)
+ }
+ }
+ }
+
+ /** Creates a handler for serving files from a static directory */
+ def createStaticHandler(resourceBase: String): ResourceHandler = {
+ val staticHandler = new ResourceHandler
+ Option(getClass.getClassLoader.getResource(resourceBase)) match {
+ case Some(res) =>
+ staticHandler.setResourceBase(res.toString)
+ case None =>
+ throw new Exception("Could not find resource path for Web UI: " + resourceBase)
+ }
+ staticHandler
+ }
+
+ /**
+ * Attempts to start a Jetty server at the supplied ip:port which uses the supplied handlers.
+ *
+ * If the desired port number is contended, continues incrementing ports until a free port is
+ * found. Returns the chosen port and the jetty Server object.
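+ *
+ * Illustrative use (the handler and path shown are hypothetical, not defined in this file):
+ * {{{
+ *   val handlers = Seq[(String, Handler)](("/ping", (req: HttpServletRequest) => "pong"))
+ *   val (server, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 3030, handlers)
+ * }}}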
+ */
+ def startJettyServer(ip: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) = {
+ val handlersToRegister = handlers.map { case(path, handler) =>
+ val contextHandler = new ContextHandler(path)
+ contextHandler.setHandler(handler)
+ contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler]
+ }
+
+ val handlerList = new HandlerList
+ handlerList.setHandlers(handlersToRegister.toArray)
+
+ @tailrec
+ def connect(currentPort: Int): (Server, Int) = {
+ val server = new Server(currentPort)
+ val pool = new QueuedThreadPool
+ pool.setDaemon(true)
+ server.setThreadPool(pool)
+ server.setHandler(handlerList)
+
+ Try { server.start() } match {
+ case s: Success[_] =>
+ sys.addShutdownHook(server.stop()) // Be kind, un-bind
+ (server, server.getConnectors.head.getLocalPort)
+ case f: Failure[_] =>
+ server.stop()
+ logInfo("Failed to create UI at port, %s. Trying again.".format(currentPort))
+ logInfo("Error was: " + f.toString)
+ connect((currentPort + 1) % 65536)
+ }
+ }
+
+ connect(port)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/Page.scala b/core/src/main/scala/org/apache/spark/ui/Page.scala
new file mode 100644
index 0000000000..b2a069a375
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/Page.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+private[spark] object Page extends Enumeration {
+ val Stages, Storage, Environment, Executors = Value
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
new file mode 100644
index 0000000000..4688effe0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.{Handler, Server}
+
+import org.apache.spark.{Logging, SparkContext, SparkEnv, Utils}
+import org.apache.spark.ui.env.EnvironmentUI
+import org.apache.spark.ui.exec.ExecutorsUI
+import org.apache.spark.ui.storage.BlockManagerUI
+import org.apache.spark.ui.jobs.JobProgressUI
+import org.apache.spark.ui.JettyUtils._
+
+/** Top level user interface for Spark */
+private[spark] class SparkUI(sc: SparkContext) extends Logging {
+ val host = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(Utils.localHostName())
+ val port = Option(System.getProperty("spark.ui.port")).getOrElse(SparkUI.DEFAULT_PORT).toInt
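+ // (the port can be overridden at launch, e.g. with -Dspark.ui.port=8080)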
+ var boundPort: Option[Int] = None
+ var server: Option[Server] = None
+
+ val handlers = Seq[(String, Handler)](
+ ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)),
+ ("/", createRedirectHandler("/stages"))
+ )
+ val storage = new BlockManagerUI(sc)
+ val jobs = new JobProgressUI(sc)
+ val env = new EnvironmentUI(sc)
+ val exec = new ExecutorsUI(sc)
+
+ // Add MetricsServlet handlers by default
+ val metricsServletHandlers = SparkEnv.get.metricsSystem.getServletHandlers
+
+ val allHandlers = storage.getHandlers ++ jobs.getHandlers ++ env.getHandlers ++
+ exec.getHandlers ++ metricsServletHandlers ++ handlers
+
+ /** Bind the HTTP server which backs this web interface */
+ def bind() {
+ try {
+ val (srv, usedPort) = JettyUtils.startJettyServer("0.0.0.0", port, allHandlers)
+ logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort))
+ server = Some(srv)
+ boundPort = Some(usedPort)
+ } catch {
+ case e: Exception =>
+ logError("Failed to create Spark JettyUtils", e)
+ System.exit(1)
+ }
+ }
+
+ /** Initialize all components of the server */
+ def start() {
+ // NOTE: This is decoupled from bind() because of the following dependency cycle:
+ // DAGScheduler() requires that the port of this server is known
+ // This server must register all handlers, including JobProgressUI, before binding
+ // JobProgressUI registers a listener with SparkContext, which requires sc to initialize
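+ // The assumed driver-side order is therefore roughly:
+ //   new SparkUI(sc) -> ui.bind() -> DAGScheduler reads ui.boundPort -> ... -> ui.start()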
+ jobs.start()
+ exec.start()
+ }
+
+ def stop() {
+ server.foreach(_.stop())
+ }
+
+ private[spark] def appUIAddress = "http://" + host + ":" + boundPort.getOrElse("-1")
+}
+
+private[spark] object SparkUI {
+ val DEFAULT_PORT = "3030"
+ val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static"
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
new file mode 100644
index 0000000000..ce1acf564c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.xml.Node
+
+import org.apache.spark.SparkContext
+
+/** Utility functions for generating XML pages with spark content. */
+private[spark] object UIUtils {
+ import Page._
+
+ /** Returns a Spark UI page with the common header and navigation bar. */
+ def headerSparkPage(content: => Seq[Node], sc: SparkContext, title: String, page: Page.Value)
+ : Seq[Node] = {
+ val jobs = page match {
+ case Stages => <li class="active"><a href="/stages">Stages</a></li>
+ case _ => <li><a href="/stages">Stages</a></li>
+ }
+ val storage = page match {
+ case Storage => <li class="active"><a href="/storage">Storage</a></li>
+ case _ => <li><a href="/storage">Storage</a></li>
+ }
+ val environment = page match {
+ case Environment => <li class="active"><a href="/environment">Environment</a></li>
+ case _ => <li><a href="/environment">Environment</a></li>
+ }
+ val executors = page match {
+ case Executors => <li class="active"><a href="/executors">Executors</a></li>
+ case _ => <li><a href="/executors">Executors</a></li>
+ }
+
+ <html>
+ <head>
+ <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
+ <link rel="stylesheet" href="/static/bootstrap.min.css" type="text/css" />
+ <link rel="stylesheet" href="/static/webui.css" type="text/css" />
+ <script src="/static/sorttable.js"></script>
+ <title>{sc.appName} - {title}</title>
+ </head>
+ <body>
+ <div class="navbar navbar-static-top">
+ <div class="navbar-inner">
+ <a href="/" class="brand"><img src="/static/spark-logo-77x50px-hd.png" /></a>
+ <ul class="nav">
+ {jobs}
+ {storage}
+ {environment}
+ {executors}
+ </ul>
+ <p class="navbar-text pull-right"><strong>{sc.appName}</strong> application UI</p>
+ </div>
+ </div>
+
+ <div class="container-fluid">
+ <div class="row-fluid">
+ <div class="span12">
+ <h3 style="vertical-align: bottom; display: inline-block;">
+ {title}
+ </h3>
+ </div>
+ </div>
+ {content}
+ </div>
+ </body>
+ </html>
+ }
+
+ /** Returns a page with the spark css/js and a simple format. Used for scheduler UI. */
+ def basicSparkPage(content: => Seq[Node], title: String): Seq[Node] = {
+ <html>
+ <head>
+ <meta http-equiv="Content-type" content="text/html; charset=utf-8" />
+ <link rel="stylesheet" href="/static/bootstrap.min.css" type="text/css" />
+ <link rel="stylesheet" href="/static/webui.css" type="text/css" />
+ <script src="/static/sorttable.js"></script>
+ <title>{title}</title>
+ </head>
+ <body>
+ <div class="container-fluid">
+ <div class="row-fluid">
+ <div class="span12">
+ <h3 style="vertical-align: middle; display: inline-block;">
+ <img src="/static/spark-logo-77x50px-hd.png" style="margin-right: 15px;" />
+ {title}
+ </h3>
+ </div>
+ </div>
+ {content}
+ </div>
+ </body>
+ </html>
+ }
+
+ /** Returns an HTML table constructed by generating a row for each object in a sequence. */
+ def listingTable[T](
+ headers: Seq[String],
+ makeRow: T => Seq[Node],
+ rows: Seq[T],
+ fixedWidth: Boolean = false): Seq[Node] = {
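+ // Illustrative call (the shape EnvironmentUI and the storage pages use):
+ //   def row(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
+ //   UIUtils.listingTable(Seq("Name", "Value"), row, Seq(("key", "value")), fixedWidth = true)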
+
+ val colWidth = 100.toDouble / headers.size
+ val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
+ var tableClass = "table table-bordered table-striped table-condensed sortable"
+ if (fixedWidth) {
+ tableClass += " table-fixed"
+ }
+
+ <table class={tableClass}>
+ <thead>{headers.map(h => <th width={colWidthAttr}>{h}</th>)}</thead>
+ <tbody>
+ {rows.map(r => makeRow(r))}
+ </tbody>
+ </table>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
new file mode 100644
index 0000000000..0ecb22d2f9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import scala.util.Random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.SparkContext._
+import org.apache.spark.scheduler.cluster.SchedulingMode
+
+
+/**
+ * Continuously generates jobs that expose various features of the WebUI (internal testing tool).
+ *
+ * Usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
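+ *
+ * e.g. (hypothetical master URL): ./spark-class spark.ui.UIWorkloadGenerator local[4] FAIR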
+ */
+private[spark] object UIWorkloadGenerator {
+ val NUM_PARTITIONS = 100
+ val INTER_JOB_WAIT_MS = 5000
+
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ System.exit(1)
+ }
+ val master = args(0)
+ val schedulingMode = SchedulingMode.withName(args(1))
+ val appName = "Spark UI Tester"
+
+ if (schedulingMode == SchedulingMode.FAIR) {
+ System.setProperty("spark.cluster.schedulingmode", "FAIR")
+ }
+ val sc = new SparkContext(master, appName)
+
+ def setProperties(s: String) = {
+ if (schedulingMode == SchedulingMode.FAIR) {
+ sc.setLocalProperty("spark.scheduler.cluster.fair.pool", s)
+ }
+ sc.setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, s)
+ }
+
+ val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS)
+ def nextFloat() = (new Random()).nextFloat()
+
+ val jobs = Seq[(String, () => Long)](
+ ("Count", baseData.count),
+ ("Cache and Count", baseData.map(x => x).cache.count),
+ ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count),
+ ("Entirely failed phase", baseData.map(x => throw new Exception).count),
+ ("Partially failed phase", {
+ baseData.map{x =>
+ val probFailure = (4.0 / NUM_PARTITIONS)
+ if (nextFloat() < probFailure) {
+ throw new Exception("This is a task failure")
+ }
+ 1
+ }.count
+ }),
+ ("Partially failed phase (longer tasks)", {
+ baseData.map{x =>
+ val probFailure = (4.0 / NUM_PARTITIONS)
+ if (nextFloat() < probFailure) {
+ Thread.sleep(100)
+ throw new Exception("This is a task failure")
+ }
+ 1
+ }.count
+ }),
+ ("Job with delays", baseData.map(x => Thread.sleep(100)).count)
+ )
+
+ while (true) {
+ for ((desc, job) <- jobs) {
+ new Thread {
+ override def run() {
+ try {
+ setProperties(desc)
+ job()
+ println("Job funished: " + desc)
+ } catch {
+ case e: Exception =>
+ println("Job Failed: " + desc)
+ }
+ }
+ }.start
+ Thread.sleep(INTER_JOB_WAIT_MS)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
new file mode 100644
index 0000000000..c5bf2acc9e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.env
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.collection.JavaConversions._
+import scala.util.Properties
+import scala.xml.Node
+
+import org.eclipse.jetty.server.Handler
+
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.ui.UIUtils
+import org.apache.spark.ui.Page.Environment
+import org.apache.spark.SparkContext
+
+
+private[spark] class EnvironmentUI(sc: SparkContext) {
+
+ def getHandlers = Seq[(String, Handler)](
+ ("/environment", (request: HttpServletRequest) => envDetails(request))
+ )
+
+ def envDetails(request: HttpServletRequest): Seq[Node] = {
+ val jvmInformation = Seq(
+ ("Java Version", "%s (%s)".format(Properties.javaVersion, Properties.javaVendor)),
+ ("Java Home", Properties.javaHome),
+ ("Scala Version", Properties.versionString),
+ ("Scala Home", Properties.scalaHome)
+ ).sorted
+ def jvmRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
+ def jvmTable =
+ UIUtils.listingTable(Seq("Name", "Value"), jvmRow, jvmInformation, fixedWidth = true)
+
+ val properties = System.getProperties.iterator.toSeq
+ val classPathProperty = properties.find { case (k, v) =>
+ k.contains("java.class.path")
+ }.getOrElse(("", ""))
+ val sparkProperties = properties.filter(_._1.startsWith("spark")).sorted
+ val otherProperties = properties.diff(sparkProperties :+ classPathProperty).sorted
+
+ val propertyHeaders = Seq("Name", "Value")
+ def propertyRow(kv: (String, String)) = <tr><td>{kv._1}</td><td>{kv._2}</td></tr>
+ val sparkPropertyTable =
+ UIUtils.listingTable(propertyHeaders, propertyRow, sparkProperties, fixedWidth = true)
+ val otherPropertyTable =
+ UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true)
+
+ val classPathEntries = classPathProperty._2
+ .split(System.getProperty("path.separator", ":"))
+ .filterNot(e => e.isEmpty)
+ .map(e => (e, "System Classpath"))
+ val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")}
+ val addedFiles = sc.addedFiles.iterator.toSeq.map{case (path, time) => (path, "Added By User")}
+ val classPath = (addedJars ++ addedFiles ++ classPathEntries).sorted
+
+ val classPathHeaders = Seq("Resource", "Source")
+ def classPathRow(data: (String, String)) = <tr><td>{data._1}</td><td>{data._2}</td></tr>
+ val classPathTable =
+ UIUtils.listingTable(classPathHeaders, classPathRow, classPath, fixedWidth = true)
+
+ val content =
+ <span>
+ <h4>Runtime Information</h4> {jvmTable}
+ <h4>Spark Properties</h4>
+ {sparkPropertyTable}
+ <h4>System Properties</h4>
+ {otherPropertyTable}
+ <h4>Classpath Entries</h4>
+ {classPathTable}
+ </span>
+
+ UIUtils.headerSparkPage(content, sc, "Environment", Environment)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
new file mode 100644
index 0000000000..efe6b474e0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -0,0 +1,136 @@
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.collection.mutable.{HashMap, HashSet}
+import scala.xml.Node
+
+import org.eclipse.jetty.server.Handler
+
+import org.apache.spark.{ExceptionFailure, Logging, Utils, SparkContext}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.TaskInfo
+import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListenerTaskEnd, SparkListener}
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.ui.Page.Executors
+import org.apache.spark.ui.UIUtils
+
+
+private[spark] class ExecutorsUI(val sc: SparkContext) {
+
+ private var _listener: Option[ExecutorsListener] = None
+ def listener = _listener.get
+
+ def start() {
+ _listener = Some(new ExecutorsListener)
+ sc.addSparkListener(listener)
+ }
+
+ def getHandlers = Seq[(String, Handler)](
+ ("/executors", (request: HttpServletRequest) => render(request))
+ )
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val storageStatusList = sc.getExecutorStorageStatus
+
+ val maxMem = storageStatusList.map(_.maxMem).fold(0L)(_+_)
+ val memUsed = storageStatusList.map(_.memUsed()).fold(0L)(_+_)
+ val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_)
+
+ val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
+ "Active tasks", "Failed tasks", "Complete tasks", "Total tasks")
+
+ def execRow(kv: Seq[String]) = {
+ <tr>
+ <td>{kv(0)}</td>
+ <td>{kv(1)}</td>
+ <td>{kv(2)}</td>
+ <td sorttable_customkey={kv(3)}>
+ {Utils.bytesToString(kv(3).toLong)} / {Utils.bytesToString(kv(4).toLong)}
+ </td>
+ <td sorttable_customkey={kv(5)}>
+ {Utils.bytesToString(kv(5).toLong)}
+ </td>
+ <td>{kv(6)}</td>
+ <td>{kv(7)}</td>
+ <td>{kv(8)}</td>
+ <td>{kv(9)}</td>
+ </tr>
+ }
+
+ val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b)
+ val execTable = UIUtils.listingTable(execHead, execRow, execInfo)
+
+ val content =
+ <div class="row-fluid">
+ <div class="span12">
+ <ul class="unstyled">
+ <li><strong>Memory:</strong>
+ {Utils.bytesToString(memUsed)} Used
+ ({Utils.bytesToString(maxMem)} Total) </li>
+ <li><strong>Disk:</strong> {Utils.bytesToString(diskSpaceUsed)} Used </li>
+ </ul>
+ </div>
+ </div>
+ <div class = "row">
+ <div class="span12">
+ {execTable}
+ </div>
+ </div>;
+
+ UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors)
+ }
+
+ def getExecInfo(a: Int): Seq[String] = {
+ val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId
+ val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort
+ val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString
+ val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString
+ val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString
+ val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString
+ val activeTasks = listener.executorToTasksActive.get(execId).map(l => l.size).getOrElse(0)
+ val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
+ val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
+ val totalTasks = activeTasks + failedTasks + completedTasks
+
+ Seq(
+ execId,
+ hostPort,
+ rddBlocks,
+ memUsed,
+ maxMem,
+ diskUsed,
+ activeTasks.toString,
+ failedTasks.toString,
+ completedTasks.toString,
+ totalTasks.toString
+ )
+ }
+
+ private[spark] class ExecutorsListener extends SparkListener with Logging {
+ val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]()
+ val executorToTasksComplete = HashMap[String, Int]()
+ val executorToTasksFailed = HashMap[String, Int]()
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ val eid = taskStart.taskInfo.executorId
+ val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+ activeTasks += taskStart.taskInfo
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val eid = taskEnd.taskInfo.executorId
+ val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+ activeTasks -= taskEnd.taskInfo
+ val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
+ taskEnd.reason match {
+ case e: ExceptionFailure =>
+ executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1
+ (Some(e), e.metrics)
+ case _ =>
+ executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
+ (None, Option(taskEnd.taskMetrics))
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
new file mode 100644
index 0000000000..3b428effaf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.{NodeSeq, Node}
+
+import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.ui.Page._
+import org.apache.spark.ui.UIUtils._
+
+
+/** Page showing list of all ongoing and recently finished stages and pools */
+private[spark] class IndexPage(parent: JobProgressUI) {
+ def listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val activeStages = listener.activeStages.toSeq
+ val completedStages = listener.completedStages.reverse.toSeq
+ val failedStages = listener.failedStages.reverse.toSeq
+ val now = System.currentTimeMillis()
+
+ var activeTime = 0L
+ for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ activeTime += t.timeRunning(now)
+ }
+
+ val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent)
+ val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent)
+ val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent)
+
+ val pools = listener.sc.getAllPools
+ val poolTable = new PoolTable(pools, listener)
+ val summary: NodeSeq =
+ <div>
+ <ul class="unstyled">
+ <li>
+ <strong>Total Duration: </strong>
+ {parent.formatDuration(now - listener.sc.startTime)}
+ </li>
+ <li><strong>Scheduling Mode:</strong> {parent.sc.getSchedulingMode}</li>
+ <li>
+ <a href="#active"><strong>Active Stages:</strong></a>
+ {activeStages.size}
+ </li>
+ <li>
+ <a href="#completed"><strong>Completed Stages:</strong></a>
+ {completedStages.size}
+ </li>
+ <li>
+ <a href="#failed"><strong>Failed Stages:</strong></a>
+ {failedStages.size}
+ </li>
+ </ul>
+ </div>
+
+ val content = summary ++
+ {if (listener.sc.getSchedulingMode == SchedulingMode.FAIR) {
+ <h4>{pools.size} Fair Scheduler Pools</h4> ++ poolTable.toNodeSeq
+ } else {
+ Seq()
+ }} ++
+ <h4 id="active">Active Stages ({activeStages.size})</h4> ++
+ activeStagesTable.toNodeSeq++
+ <h4 id="completed">Completed Stages ({completedStages.size})</h4> ++
+ completedStagesTable.toNodeSeq++
+ <h4 id ="failed">Failed Stages ({failedStages.size})</h4> ++
+ failedStagesTable.toNodeSeq
+
+ headerSparkPage(content, parent.sc, "Spark Stages", Stages)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
new file mode 100644
index 0000000000..ae02226300
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -0,0 +1,156 @@
+package org.apache.spark.ui.jobs
+
+import scala.Seq
+import scala.collection.mutable.{ListBuffer, HashMap, HashSet}
+
+import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils}
+import org.apache.spark.scheduler._
+import org.apache.spark.scheduler.cluster.TaskInfo
+import org.apache.spark.executor.TaskMetrics
+import collection.mutable
+
+/**
+ * Tracks task-level information to be displayed in the UI.
+ *
+ * All access to the data structures in this class must be synchronized on the
+ * class, since the UI thread and the DAGScheduler event loop may otherwise
+ * be reading/updating the internal data structures concurrently.
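+ *
+ * For example, the UI pages (IndexPage, StagePage) read state as
+ * {{{ listener.synchronized { val active = listener.activeStages.toSeq } }}}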
+ */
+private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener {
+ // How many stages to remember
+ val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
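+ // (e.g. launching with -Dspark.ui.retained_stages=200 keeps only the 200 most recent stages)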
+ val DEFAULT_POOL_NAME = "default"
+
+ val stageToPool = new HashMap[Stage, String]()
+ val stageToDescription = new HashMap[Stage, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+
+ val activeStages = HashSet[Stage]()
+ val completedStages = ListBuffer[Stage]()
+ val failedStages = ListBuffer[Stage]()
+
+ // Total metrics reflect metrics only for completed tasks
+ var totalTime = 0L
+ var totalShuffleRead = 0L
+ var totalShuffleWrite = 0L
+
+ val stageToTime = HashMap[Int, Long]()
+ val stageToShuffleRead = HashMap[Int, Long]()
+ val stageToShuffleWrite = HashMap[Int, Long]()
+ val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageToTasksComplete = HashMap[Int, Int]()
+ val stageToTasksFailed = HashMap[Int, Int]()
+ val stageToTaskInfos =
+ HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
+
+ override def onJobStart(jobStart: SparkListenerJobStart) {}
+
+ override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
+ val stage = stageCompleted.stageInfo.stage
+ poolToActiveStages(stageToPool(stage)) -= stage
+ activeStages -= stage
+ completedStages += stage
+ trimIfNecessary(completedStages)
+ }
+
+ /** If too many stages are retained, drop the oldest ones and their per-stage state so they can be garbage collected */
+ def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ if (stages.size > RETAINED_STAGES) {
+ val toRemove = RETAINED_STAGES / 10
+ stages.takeRight(toRemove).foreach( s => {
+ stageToTaskInfos.remove(s.id)
+ stageToTime.remove(s.id)
+ stageToShuffleRead.remove(s.id)
+ stageToShuffleWrite.remove(s.id)
+ stageToTasksActive.remove(s.id)
+ stageToTasksComplete.remove(s.id)
+ stageToTasksFailed.remove(s.id)
+ stageToPool.remove(s)
+ if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ })
+ stages.trimEnd(toRemove)
+ }
+ }
+
+ /** Under FIFO scheduling, every stage is placed in the "default" pool, though the pool itself has no scheduling significance there */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
+ val stage = stageSubmitted.stage
+ activeStages += stage
+
+ val poolName = Option(stageSubmitted.properties).map {
+ p => p.getProperty("spark.scheduler.cluster.fair.pool", DEFAULT_POOL_NAME)
+ }.getOrElse(DEFAULT_POOL_NAME)
+ stageToPool(stage) = poolName
+
+ val description = Option(stageSubmitted.properties).flatMap {
+ p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
+ }
+ description.map(d => stageToDescription(stage) = d)
+
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ stages += stage
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
+ val sid = taskStart.task.stageId
+ val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ tasksActive += taskStart.taskInfo
+ val taskList = stageToTaskInfos.getOrElse(
+ sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
+ taskList += ((taskStart.taskInfo, None, None))
+ stageToTaskInfos(sid) = taskList
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+ val sid = taskEnd.task.stageId
+ val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ tasksActive -= taskEnd.taskInfo
+ val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
+ taskEnd.reason match {
+ case e: ExceptionFailure =>
+ stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ (Some(e), e.metrics)
+ case _ =>
+ stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ (None, Option(taskEnd.taskMetrics))
+ }
+
+ stageToTime.getOrElseUpdate(sid, 0L)
+ val time = metrics.map(m => m.executorRunTime).getOrElse(0)
+ stageToTime(sid) += time
+ totalTime += time
+
+ stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
+ s.remoteBytesRead).getOrElse(0L)
+ stageToShuffleRead(sid) += shuffleRead
+ totalShuffleRead += shuffleRead
+
+ stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
+ s.shuffleBytesWritten).getOrElse(0L)
+ stageToShuffleWrite(sid) += shuffleWrite
+ totalShuffleWrite += shuffleWrite
+
+ val taskList = stageToTaskInfos.getOrElse(
+ sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
+ taskList -= ((taskEnd.taskInfo, None, None))
+ taskList += ((taskEnd.taskInfo, metrics, failureInfo))
+ stageToTaskInfos(sid) = taskList
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
+ jobEnd.jobResult match {
+ case JobFailed(ex, Some(stage)) =>
+ activeStages -= stage
+ poolToActiveStages(stageToPool(stage)) -= stage
+ failedStages += stage
+ trimIfNecessary(failedStages)
+ case _ =>
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
new file mode 100644
index 0000000000..1bb7638bd9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import akka.util.Duration
+
+import java.text.SimpleDateFormat
+
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Handler
+
+import scala.Seq
+import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer}
+
+import org.apache.spark.ui.JettyUtils._
+import org.apache.spark.{ExceptionFailure, SparkContext, Success, Utils}
+import org.apache.spark.scheduler._
+import collection.mutable
+import org.apache.spark.scheduler.cluster.SchedulingMode
+import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
+
+/** Web UI showing progress status of all jobs in the given SparkContext. */
+private[spark] class JobProgressUI(val sc: SparkContext) {
+ private var _listener: Option[JobProgressListener] = None
+ def listener = _listener.get
+ val dateFmt = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+
+ private val indexPage = new IndexPage(this)
+ private val stagePage = new StagePage(this)
+ private val poolPage = new PoolPage(this)
+
+ def start() {
+ _listener = Some(new JobProgressListener(sc))
+ sc.addSparkListener(listener)
+ }
+
+ def formatDuration(ms: Long) = Utils.msDurationToString(ms)
+
+ def getHandlers = Seq[(String, Handler)](
+ ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)),
+ ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)),
+ ("/stages", (request: HttpServletRequest) => indexPage.render(request))
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
new file mode 100644
index 0000000000..ce92b6932b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -0,0 +1,32 @@
+package org.apache.spark.ui.jobs
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.{NodeSeq, Node}
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.scheduler.Stage
+import org.apache.spark.ui.UIUtils._
+import org.apache.spark.ui.Page._
+
+/** Page showing specific pool details */
+private[spark] class PoolPage(parent: JobProgressUI) {
+ def listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val poolName = request.getParameter("poolname")
+ val poolToActiveStages = listener.poolToActiveStages
+ val activeStages = poolToActiveStages.get(poolName).toSeq.flatten
+ val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent)
+
+ val pool = listener.sc.getPoolForName(poolName).get
+ val poolTable = new PoolTable(Seq(pool), listener)
+
+ val content = <h4>Summary </h4> ++ poolTable.toNodeSeq() ++
+ <h4>{activeStages.size} Active Stages</h4> ++ activeStagesTable.toNodeSeq()
+
+ headerSparkPage(content, parent.sc, "Fair Scheduler Pool: " + poolName, Stages)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
new file mode 100644
index 0000000000..f31465e59d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -0,0 +1,55 @@
+package org.apache.spark.ui.jobs
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.xml.Node
+
+import org.apache.spark.scheduler.Stage
+import org.apache.spark.scheduler.cluster.Schedulable
+
+/** Table showing list of pools */
+private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
+
+ var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+
+ def toNodeSeq(): Seq[Node] = {
+ listener.synchronized {
+ poolTable(poolRow, pools)
+ }
+ }
+
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ rows: Seq[Schedulable]
+ ): Seq[Node] = {
+ <table class="table table-bordered table-striped table-condensed sortable table-fixed">
+ <thead>
+ <th>Pool Name</th>
+ <th>Minimum Share</th>
+ <th>Pool Weight</th>
+ <th>Active Stages</th>
+ <th>Running Tasks</th>
+ <th>SchedulingMode</th>
+ </thead>
+ <tbody>
+ {rows.map(r => makeRow(r, poolToActiveStages))}
+ </tbody>
+ </table>
+ }
+
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ : Seq[Node] = {
+ val activeStages = poolToActiveStages.get(p.name) match {
+ case Some(stages) => stages.size
+ case None => 0
+ }
+ <tr>
+ <td><a href={"/stages/pool?poolname=%s".format(p.name)}>{p.name}</a></td>
+ <td>{p.minShare}</td>
+ <td>{p.weight}</td>
+ <td>{activeStages}</td>
+ <td>{p.runningTasks}</td>
+ <td>{p.schedulingMode}</td>
+ </tr>
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
new file mode 100644
index 0000000000..2fe85bc0cf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import java.util.Date
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.ui.UIUtils._
+import org.apache.spark.ui.Page._
+import org.apache.spark.util.Distribution
+import org.apache.spark.{ExceptionFailure, Utils}
+import org.apache.spark.scheduler.cluster.TaskInfo
+import org.apache.spark.executor.TaskMetrics
+
+/** Page showing statistics and task list for a given stage */
+private[spark] class StagePage(parent: JobProgressUI) {
+ def listener = parent.listener
+ val dateFmt = parent.dateFmt
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val stageId = request.getParameter("id").toInt
+ val now = System.currentTimeMillis()
+
+ if (!listener.stageToTaskInfos.contains(stageId)) {
+ val content =
+ <div>
+ <h4>Summary Metrics</h4> No tasks have started yet
+ <h4>Tasks</h4> No tasks have started yet
+ </div>
+ return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
+ }
+
+ val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+
+ val numCompleted = tasks.count(_._1.finished)
+ val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val hasShuffleRead = shuffleReadBytes > 0
+ val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val hasShuffleWrite = shuffleWriteBytes > 0
+
+ var activeTime = 0L
+ listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+
+ val summary =
+ <div>
+ <ul class="unstyled">
+ <li>
+ <strong>CPU time: </strong>
+ {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ </li>
+ {if (hasShuffleRead)
+ <li>
+ <strong>Shuffle read: </strong>
+ {Utils.bytesToString(shuffleReadBytes)}
+ </li>
+ }
+ {if (hasShuffleWrite)
+ <li>
+ <strong>Shuffle write: </strong>
+ {Utils.bytesToString(shuffleWriteBytes)}
+ </li>
+ }
+ </ul>
+ </div>
+
+ val taskHeaders: Seq[String] =
+ Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
+ Seq("GC Time") ++
+ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
+ {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ Seq("Errors")
+
+ val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
+
+ // Excludes tasks which failed and have incomplete metrics
+ val validTasks = tasks.filter(t => t._1.status == "SUCCESS" && (t._2.isDefined))
+
+ val summaryTable: Option[Seq[Node]] =
+ if (validTasks.size == 0) {
+ None
+ } else {
+ val serviceTimes = validTasks.map{case (info, metrics, exception) =>
+ metrics.get.executorRunTime.toDouble}
+ val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
+ ms => parent.formatDuration(ms.toLong))
+
+ def getQuantileCols(data: Seq[Double]) =
+ Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong))
+
+ val shuffleReadSizes = validTasks.map {
+ case(info, metrics, exception) =>
+ metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
+ }
+ val shuffleReadQuantiles = "Shuffle Read (Remote)" +: getQuantileCols(shuffleReadSizes)
+
+ val shuffleWriteSizes = validTasks.map {
+ case(info, metrics, exception) =>
+ metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble
+ }
+ val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
+
+ val listings: Seq[Seq[String]] = Seq(serviceQuantiles,
+ if (hasShuffleRead) shuffleReadQuantiles else Nil,
+ if (hasShuffleWrite) shuffleWriteQuantiles else Nil)
+
+ val quantileHeaders = Seq("Metric", "Min", "25th percentile",
+ "Median", "75th percentile", "Max")
+ def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
+ Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
+ }
+
+ val content =
+ summary ++
+ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
+ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
+ <h4>Tasks</h4> ++ taskTable;
+
+ headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
+ }
+ }
+
+
+ def taskRow(shuffleRead: Boolean, shuffleWrite: Boolean)
+ (taskData: (TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])): Seq[Node] = {
+ def fmtStackTrace(trace: Seq[StackTraceElement]): Seq[Node] =
+ trace.map(e => <span style="display:block;">{e.toString}</span>)
+ val (info, metrics, exception) = taskData
+
+ val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis())
+ else metrics.map(m => m.executorRunTime).getOrElse(1)
+ val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration)
+ else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
+ val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+
+ <tr>
+ <td>{info.taskId}</td>
+ <td>{info.status}</td>
+ <td>{info.taskLocality}</td>
+ <td>{info.host}</td>
+ <td>{dateFmt.format(new Date(info.launchTime))}</td>
+ <td sorttable_customkey={duration.toString}>
+ {formatDuration}
+ </td>
+ <td sorttable_customkey={gcTime.toString}>
+ {if (gcTime > 0) parent.formatDuration(gcTime) else ""}
+ </td>
+ {if (shuffleRead) {
+ <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
+ Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+ }}
+ {if (shuffleWrite) {
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+ }}
+ <td>{exception.map(e =>
+ <span>
+ {e.className} ({e.description})<br/>
+ {fmtStackTrace(e.stackTrace)}
+ </span>).getOrElse("")}
+ </td>
+ </tr>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
new file mode 100644
index 0000000000..beb0574548
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -0,0 +1,107 @@
+package org.apache.spark.ui.jobs
+
+import java.util.Date
+
+import scala.xml.Node
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.Utils
+import org.apache.spark.scheduler.cluster.{SchedulingMode, TaskInfo}
+import org.apache.spark.scheduler.Stage
+
+
+/** Page showing list of all ongoing and recently finished stages */
+private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+
+ val listener = parent.listener
+ val dateFmt = parent.dateFmt
+ val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR
+
+ def toNodeSeq(): Seq[Node] = {
+ listener.synchronized {
+ stageTable(stageRow, stages)
+ }
+ }
+
+ /** Table of stages, with an extra "Pool Name" column when the fair scheduler is in use. */
+ private def stageTable[T](makeRow: T => Seq[Node], rows: Seq[T]): Seq[Node] = {
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <th>Stage Id</th>
+ {if (isFairScheduler) {<th>Pool Name</th>} else {}}
+ <th>Description</th>
+ <th>Submitted</th>
+ <th>Duration</th>
+ <th>Tasks: Succeeded/Total</th>
+ <th>Shuffle Read</th>
+ <th>Shuffle Write</th>
+ </thead>
+ <tbody>
+ {rows.map(r => makeRow(r))}
+ </tbody>
+ </table>
+ }
+
+ private def makeProgressBar(started: Int, completed: Int, failed: String, total: Int): Seq[Node] = {
+ val completeWidth = "width: %s%%".format((completed.toDouble/total)*100)
+ val startWidth = "width: %s%%".format((started.toDouble/total)*100)
+
+ <div class="progress">
+ <span style="text-align:center; position:absolute; width:100%;">
+ {completed}/{total} {failed}
+ </span>
+ <div class="bar bar-completed" style={completeWidth}></div>
+ <div class="bar bar-running" style={startWidth}></div>
+ </div>
+ }
+
+
+ private def stageRow(s: Stage): Seq[Node] = {
+ val submissionTime = s.submissionTime match {
+ case Some(t) => dateFmt.format(new Date(t))
+ case None => "Unknown"
+ }
+
+ val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ case 0 => ""
+ case b => Utils.bytesToString(b)
+ }
+ val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+ case 0 => ""
+ case b => Utils.bytesToString(b)
+ }
+
+ val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
+ val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ case f if f > 0 => "(%s failed)".format(f)
+ case _ => ""
+ }
+ val totalTasks = s.numPartitions
+
+ val poolName = listener.stageToPool.get(s)
+
+ val nameLink = <a href={"/stages/stage?id=%s".format(s.id)}>{s.name}</a>
+ val description = listener.stageToDescription.get(s)
+ .map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
+ val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
+ val duration = s.submissionTime.map(t => finishTime - t)
+
+ <tr>
+ <td>{s.id}</td>
+ {if (isFairScheduler) {
+ <td><a href={"/stages/pool?poolname=%s".format(poolName.get)}>{poolName.get}</a></td>}
+ }
+ <td>{description}</td>
+ <td valign="middle">{submissionTime}</td>
+ <td sorttable_customkey={duration.getOrElse(-1).toString}>
+ {duration.map(d => parent.formatDuration(d)).getOrElse("Unknown")}
+ </td>
+ <td class="progress-cell">
+ {makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)}
+ </td>
+ <td>{shuffleRead}</td>
+ <td>{shuffleWrite}</td>
+ </tr>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
new file mode 100644
index 0000000000..1d633d374a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.storage
+
+import akka.util.Duration
+
+import javax.servlet.http.HttpServletRequest
+
+import org.eclipse.jetty.server.Handler
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.ui.JettyUtils._
+
+/** Web UI showing storage status of all RDD's in the given SparkContext. */
+private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
+ implicit val timeout = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+
+ val indexPage = new IndexPage(this)
+ val rddPage = new RDDPage(this)
+
+ def getHandlers = Seq[(String, Handler)](
+ ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)),
+ ("/storage", (request: HttpServletRequest) => indexPage.render(request))
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
new file mode 100644
index 0000000000..1eb4a7a85e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.storage
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.storage.{RDDInfo, StorageUtils}
+import org.apache.spark.Utils
+import org.apache.spark.ui.UIUtils._
+import org.apache.spark.ui.Page._
+
+/** Page showing list of RDD's currently stored in the cluster */
+private[spark] class IndexPage(parent: BlockManagerUI) {
+ val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val storageStatusList = sc.getExecutorStorageStatus
+ // Calculate macro-level statistics
+
+ val rddHeaders = Seq(
+ "RDD Name",
+ "Storage Level",
+ "Cached Partitions",
+ "Fraction Cached",
+ "Size in Memory",
+ "Size on Disk")
+ val rdds = StorageUtils.rddInfoFromStorageStatus(storageStatusList, sc)
+ val content = listingTable(rddHeaders, rddRow, rdds)
+
+ headerSparkPage(content, parent.sc, "Storage ", Storage)
+ }
+
+ def rddRow(rdd: RDDInfo): Seq[Node] = {
+ <tr>
+ <td>
+ <a href={"/storage/rdd?id=%s".format(rdd.id)}>
+ {rdd.name}
+ </a>
+ </td>
+ <td>{rdd.storageLevel.description}
+ </td>
+ <td>{rdd.numCachedPartitions}</td>
+ <td>{"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)}</td>
+ <td>{Utils.bytesToString(rdd.memSize)}</td>
+ <td>{Utils.bytesToString(rdd.diskSize)}</td>
+ </tr>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
new file mode 100644
index 0000000000..37baf17f7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.storage
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.Node
+
+import org.apache.spark.Utils
+import org.apache.spark.storage.{StorageStatus, StorageUtils}
+import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
+import org.apache.spark.ui.UIUtils._
+import org.apache.spark.ui.Page._
+
+
+/** Page showing storage details for a given RDD */
+private[spark] class RDDPage(parent: BlockManagerUI) {
+ val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val id = request.getParameter("id")
+ val prefix = "rdd_" + id.toString
+ val storageStatusList = sc.getExecutorStorageStatus
+ val filteredStorageStatusList = StorageUtils.
+ filterStorageStatusByPrefix(storageStatusList, prefix)
+ val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
+
+ val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
+ val workers = filteredStorageStatusList.map((prefix, _))
+ val workerTable = listingTable(workerHeaders, workerRow, workers)
+
+ val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
+ "Executors")
+
+ val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1)
+ val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
+ val blocks = blockStatuses.map {
+ case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
+ }
+ val blockTable = listingTable(blockHeaders, blockRow, blocks)
+
+ val content =
+ <div class="row-fluid">
+ <div class="span12">
+ <ul class="unstyled">
+ <li>
+ <strong>Storage Level:</strong>
+ {rddInfo.storageLevel.description}
+ </li>
+ <li>
+ <strong>Cached Partitions:</strong>
+ {rddInfo.numCachedPartitions}
+ </li>
+ <li>
+ <strong>Total Partitions:</strong>
+ {rddInfo.numPartitions}
+ </li>
+ <li>
+ <strong>Memory Size:</strong>
+ {Utils.bytesToString(rddInfo.memSize)}
+ </li>
+ <li>
+ <strong>Disk Size:</strong>
+ {Utils.bytesToString(rddInfo.diskSize)}
+ </li>
+ </ul>
+ </div>
+ </div>
+
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> Data Distribution on {workers.size} Executors </h4>
+ {workerTable}
+ </div>
+ </div>
+
+ <div class="row-fluid">
+ <div class="span12">
+ <h4> {blocks.size} Partitions </h4>
+ {blockTable}
+ </div>
+ </div>;
+
+ headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
+ }
+
+ def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = {
+ val (id, block, locations) = row
+ <tr>
+ <td>{id}</td>
+ <td>
+ {block.storageLevel.description}
+ </td>
+ <td sorttable_customkey={block.memSize.toString}>
+ {Utils.bytesToString(block.memSize)}
+ </td>
+ <td sorttable_customkey={block.diskSize.toString}>
+ {Utils.bytesToString(block.diskSize)}
+ </td>
+ <td>
+ {locations.map(l => <span>{l}<br/></span>)}
+ </td>
+ </tr>
+ }
+
+ def workerRow(worker: (String, StorageStatus)): Seq[Node] = {
+ val (prefix, status) = worker
+ <tr>
+ <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
+ <td>
+ {Utils.bytesToString(status.memUsed(prefix))}
+ ({Utils.bytesToString(status.memRemaining)} Remaining)
+ </td>
+ <td>{Utils.bytesToString(status.diskUsed(prefix))}</td>
+ </tr>
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
new file mode 100644
index 0000000000..d4c5065c3f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import akka.actor.{ActorSystem, ExtendedActorSystem}
+import com.typesafe.config.ConfigFactory
+import akka.util.duration._
+import akka.remote.RemoteActorRefProvider
+
+
+/**
+ * Various utility classes for working with Akka.
+ */
+private[spark] object AkkaUtils {
+
+ /**
+ * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the
+ * ActorSystem itself and its port (which is hard to get from Akka).
+ *
+ * Note: the `name` parameter is important, as even if a client sends a message to the right
+ * host and port, Akka will drop the message if the system name is incorrect.
+ */
+ def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
+ val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
+ val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
+ val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
+ val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
+ val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+ // 10 seconds is the default Akka timeout, but in a cluster we need a higher value by default.
+ val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
+
+ val akkaConf = ConfigFactory.parseString("""
+ akka.daemonic = on
+ akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
+ akka.stdout-loglevel = "ERROR"
+ akka.actor.provider = "akka.remote.RemoteActorRefProvider"
+ akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
+ akka.remote.netty.hostname = "%s"
+ akka.remote.netty.port = %d
+ akka.remote.netty.connection-timeout = %ds
+ akka.remote.netty.message-frame-size = %d MiB
+ akka.remote.netty.execution-pool-size = %d
+ akka.actor.default-dispatcher.throughput = %d
+ akka.remote.log-remote-lifecycle-events = %s
+ akka.remote.netty.write-timeout = %ds
+ """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
+ lifecycleEvents, akkaWriteTimeout))
+
+ val actorSystem = ActorSystem(name, akkaConf)
+
+ // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
+ // hack because Akka doesn't let you figure out the port through the public API yet.
+ val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
+ val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get
+ return (actorSystem, boundPort)
+ }
+}
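
A minimal usage sketch (not part of the patch), assuming the Spark and Akka jars from this branch are on the classpath and the code sits in an org.apache.spark package, since AkkaUtils is private[spark]. Passing port 0 lets the OS pick a free port; the actual port comes back in the returned tuple:

    import org.apache.spark.util.AkkaUtils

    object AkkaUtilsSketch {
      def main(args: Array[String]) {
        // Port 0 asks the OS for any free port; createActorSystem returns the bound port.
        val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
        println("Actor system '" + actorSystem.name + "' bound to port " + boundPort)
        actorSystem.shutdown()
      }
    }
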
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
new file mode 100644
index 0000000000..0b51c23f7b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.Serializable
+import java.util.{PriorityQueue => JPriorityQueue}
+import scala.collection.generic.Growable
+import scala.collection.JavaConverters._
+
+/**
+ * Bounded priority queue. This class wraps the original PriorityQueue
+ * class and modifies it such that only the top K elements are retained.
+ * The top K elements are defined by an implicit Ordering[A].
+ */
+class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
+ extends Iterable[A] with Growable[A] with Serializable {
+
+ private val underlying = new JPriorityQueue[A](maxSize, ord)
+
+ override def iterator: Iterator[A] = underlying.iterator.asScala
+
+ override def ++=(xs: TraversableOnce[A]): this.type = {
+ xs.foreach { this += _ }
+ this
+ }
+
+ override def +=(elem: A): this.type = {
+ if (size < maxSize) underlying.offer(elem)
+ else maybeReplaceLowest(elem)
+ this
+ }
+
+ override def +=(elem1: A, elem2: A, elems: A*): this.type = {
+ this += elem1 += elem2 ++= elems
+ }
+
+ override def clear() { underlying.clear() }
+
+ private def maybeReplaceLowest(a: A): Boolean = {
+ val head = underlying.peek()
+ if (head != null && ord.gt(a, head)) {
+ underlying.poll()
+ underlying.offer(a)
+ } else false
+ }
+}
+
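
A quick usage sketch (not part of the patch): with maxSize = 3 and the default Ordering[Int], only the three largest values survive.

    import org.apache.spark.util.BoundedPriorityQueue

    val queue = new BoundedPriorityQueue[Int](3)
    queue ++= Seq(5, 1, 9, 7, 3)
    // Smaller elements are evicted as larger ones arrive.
    println(queue.toSeq.sorted.mkString(", "))   // 5, 7, 9
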
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
new file mode 100644
index 0000000000..e214d2a519
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.InputStream
+import java.nio.ByteBuffer
+import org.apache.spark.storage.BlockManager
+
+/**
+ * Reads data from a ByteBuffer, and optionally cleans it up using BlockManager.dispose()
+ * at the end of the stream (e.g. to close a memory-mapped file).
+ */
+private[spark]
+class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false)
+ extends InputStream {
+
+ override def read(): Int = {
+ if (buffer == null || buffer.remaining() == 0) {
+ cleanUp()
+ -1
+ } else {
+ buffer.get() & 0xFF
+ }
+ }
+
+ override def read(dest: Array[Byte]): Int = {
+ read(dest, 0, dest.length)
+ }
+
+ override def read(dest: Array[Byte], offset: Int, length: Int): Int = {
+ if (buffer == null || buffer.remaining() == 0) {
+ cleanUp()
+ -1
+ } else {
+ val amountToGet = math.min(buffer.remaining(), length)
+ buffer.get(dest, offset, amountToGet)
+ amountToGet
+ }
+ }
+
+ override def skip(bytes: Long): Long = {
+ if (buffer != null) {
+ val amountToSkip = math.min(bytes, buffer.remaining).toInt
+ buffer.position(buffer.position + amountToSkip)
+ if (buffer.remaining() == 0) {
+ cleanUp()
+ }
+ amountToSkip
+ } else {
+ 0L
+ }
+ }
+
+ /**
+ * Clean up the buffer, and potentially dispose of it using BlockManager.dispose().
+ */
+ private def cleanUp() {
+ if (buffer != null) {
+ if (dispose) {
+ BlockManager.dispose(buffer)
+ }
+ buffer = null
+ }
+ }
+}
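
A usage sketch (not part of the patch; the class is private[spark], so this assumes code living in an org.apache.spark package):

    import java.nio.ByteBuffer
    import org.apache.spark.util.ByteBufferInputStream

    val buf = ByteBuffer.wrap("spark".getBytes("UTF-8"))
    val in = new ByteBufferInputStream(buf)        // dispose defaults to false
    val dest = new Array[Byte](5)
    val n = in.read(dest)                          // reads 5 bytes; the buffer is now exhausted
    println(new String(dest, 0, n, "UTF-8"))       // spark
    println(in.read())                             // -1, end of stream
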
diff --git a/core/src/main/scala/org/apache/spark/util/Clock.scala b/core/src/main/scala/org/apache/spark/util/Clock.scala
new file mode 100644
index 0000000000..97c2b45aab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/Clock.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * An interface to represent clocks, so that they can be mocked out in unit tests.
+ */
+private[spark] trait Clock {
+ def getTime(): Long
+}
+
+private[spark] object SystemClock extends Clock {
+ def getTime(): Long = System.currentTimeMillis()
+}
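
A hypothetical test double (not part of the patch) illustrating why the trait exists; the name ManualClock and the advance method are made up for this sketch:

    // Assumes it is declared inside an org.apache.spark package, since Clock is private[spark].
    private[spark] class ManualClock(private var now: Long) extends Clock {
      def getTime(): Long = now
      def advance(millis: Long) { now += millis }   // move time forward deterministically in tests
    }
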
diff --git a/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
new file mode 100644
index 0000000000..dc15a38b29
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CompletionIterator.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Wrapper around an iterator which calls a completion method after it has successfully iterated
+ * through all the elements.
+ */
+abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
+ def next = sub.next
+ def hasNext = {
+ val r = sub.hasNext
+ if (!r) {
+ completion
+ }
+ r
+ }
+
+ def completion()
+}
+
+object CompletionIterator {
+ def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = {
+ new CompletionIterator[A,I](sub) {
+ def completion() = completionFunction
+ }
+ }
+}
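
A usage sketch (not part of the patch; the file path is hypothetical): close a resource exactly when iteration finishes.

    import scala.io.Source
    import org.apache.spark.util.CompletionIterator

    val source = Source.fromFile("/tmp/example.txt")
    val lines: Iterator[String] =
      CompletionIterator[String, Iterator[String]](source.getLines(), source.close())
    // The by-name completion function runs once hasNext first returns false.
    lines.foreach(println)
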
diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala
new file mode 100644
index 0000000000..33bf3562fe
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.PrintStream
+
+/**
+ * Utility for getting some stats from a small sample of numeric values, with some handy summary functions.
+ *
+ * Entirely in memory; not intended as a good way to compute stats over large data sets.
+ *
+ * Assumes it is given a non-empty set of data.
+ */
+class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) {
+ require(startIdx < endIdx)
+ def this(data: Traversable[Double]) = this(data.toArray, 0, data.size)
+ java.util.Arrays.sort(data, startIdx, endIdx)
+ val length = endIdx - startIdx
+
+ val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0)
+
+ /**
+ * Get the value of the distribution at the given probabilities.
+ * @param probabilities the quantile probabilities to look up, each given from 0 to 1
+ */
+ def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = {
+ probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))}
+ }
+
+ private def closestIndex(p: Double) = {
+ math.min((p * length).toInt + startIdx, endIdx - 1)
+ }
+
+ def showQuantiles(out: PrintStream = System.out) = {
+ out.println("min\t25%\t50%\t75%\tmax")
+ getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")}
+ out.println
+ }
+
+ def statCounter = StatCounter(data.slice(startIdx, endIdx))
+
+ /**
+ * Print a summary of this distribution to the given PrintStream.
+ * @param out the stream to print to (defaults to System.out)
+ */
+ def summary(out: PrintStream = System.out) {
+ out.println(statCounter)
+ showQuantiles(out)
+ }
+}
+
+object Distribution {
+
+ def apply(data: Traversable[Double]): Option[Distribution] = {
+ if (data.size > 0)
+ Some(new Distribution(data))
+ else
+ None
+ }
+
+ def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) {
+ out.println("min\t25%\t50%\t75%\tmax")
+ quantiles.foreach{q => out.print(q + "\t")}
+ out.println
+ }
+}
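
A usage sketch (not part of the patch), e.g. for a handful of task durations in milliseconds:

    import org.apache.spark.util.Distribution

    val taskTimes = Seq(12.0, 15.0, 9.0, 30.0, 21.0)
    Distribution(taskTimes) match {
      case Some(d) =>
        d.summary()                          // prints the StatCounter line plus min/25%/50%/75%/max
        println(d.getQuantiles(Seq(0.9)))    // value at the 90th percentile
      case None =>
        println("no data")
    }
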
diff --git a/core/src/main/scala/org/apache/spark/util/IdGenerator.scala b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala
new file mode 100644
index 0000000000..17e55f7996
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/IdGenerator.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.atomic.AtomicInteger
+
+/**
+ * A util used to get a unique generation ID. This is a wrapper around Java's
+ * AtomicInteger. An example usage is in BlockManager, where each BlockManager
+ * instance would start an Akka actor and we use this utility to assign the Akka
+ * actors unique names.
+ */
+private[spark] class IdGenerator {
+ private val id = new AtomicInteger
+ def next: Int = id.incrementAndGet
+}
diff --git a/core/src/main/scala/org/apache/spark/util/IntParam.scala b/core/src/main/scala/org/apache/spark/util/IntParam.scala
new file mode 100644
index 0000000000..626bb49eea
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/IntParam.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * An extractor object for parsing strings into integers.
+ */
+private[spark] object IntParam {
+ def unapply(str: String): Option[Int] = {
+ try {
+ Some(str.toInt)
+ } catch {
+ case e: NumberFormatException => None
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MemoryParam.scala b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala
new file mode 100644
index 0000000000..0ee6707826
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MemoryParam.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.Utils
+
+/**
+ * An extractor object for parsing JVM memory strings, such as "10g", into an Int representing
+ * the number of megabytes. Supports the same formats as Utils.memoryStringToMb.
+ */
+private[spark] object MemoryParam {
+ def unapply(str: String): Option[Int] = {
+ try {
+ Some(Utils.memoryStringToMb(str))
+ } catch {
+ case e: NumberFormatException => None
+ }
+ }
+}
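
A sketch of how these extractors are typically used in an argument-parsing loop (hypothetical option names; assumes the code lives in an org.apache.spark package, since both objects are private[spark]):

    def parse(args: List[String]) {
      args match {
        case "--port" :: IntParam(port) :: tail =>
          println("port = " + port); parse(tail)
        case "--memory" :: MemoryParam(mb) :: tail =>   // accepts "512m", "2g", ...
          println("memory = " + mb + " MB"); parse(tail)
        case Nil =>                                     // done
        case other :: _ =>
          sys.error("Unrecognized argument: " + other)
      }
    }
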
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
new file mode 100644
index 0000000000..a430a75451
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors}
+import java.util.{TimerTask, Timer}
+import org.apache.spark.Logging
+
+
+/**
+ * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries)
+ */
+class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging {
+ private val delaySeconds = MetadataCleaner.getDelaySeconds
+ private val periodSeconds = math.max(10, delaySeconds / 10)
+ private val timer = new Timer(name + " cleanup timer", true)
+
+ private val task = new TimerTask {
+ override def run() {
+ try {
+ cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000))
+ logInfo("Ran metadata cleaner for " + name)
+ } catch {
+ case e: Exception => logError("Error running cleanup task for " + name, e)
+ }
+ }
+ }
+
+ if (delaySeconds > 0) {
+ logDebug(
+ "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
+ "and period of " + periodSeconds + " secs")
+ timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
+ }
+
+ def cancel() {
+ timer.cancel()
+ }
+}
+
+
+object MetadataCleaner {
+ def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt
+ def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) }
+}
+
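
A usage sketch (not part of the patch) wiring the cleaner to a TimeStampedHashMap (defined later in this patch); the timer only fires when spark.cleaner.ttl is set to a positive number of seconds:

    import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}

    val stageIdToName = new TimeStampedHashMap[Int, String]()
    // clearOldValues(threshTime) drops entries inserted before threshTime.
    val cleaner = new MetadataCleaner("stageIdToName", stageIdToName.clearOldValues)
    // ... use the map ...
    cleaner.cancel()
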
diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
new file mode 100644
index 0000000000..34f1f6606f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+
+/**
+ * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to
+ * minimize object allocation.
+ *
+ * @param _1 Element 1 of this MutablePair
+ * @param _2 Element 2 of this MutablePair
+ */
+case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
+ @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
+ (var _1: T1, var _2: T2)
+ extends Product2[T1, T2]
+{
+ override def toString = "(" + _1 + "," + _2 + ")"
+
+ override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]]
+}
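
A usage sketch (not part of the patch): reuse one pair object across records instead of allocating a Tuple2 per element.

    import org.apache.spark.util.MutablePair

    val pair = new MutablePair[Int, String](0, "")
    val records = Seq((1, "a"), (2, "b"), (3, "c"))
    records.foreach { case (k, v) =>
      pair._1 = k
      pair._2 = v
      println(pair)   // (1,a), (2,b), (3,c): the same object is mutated and printed each time
    }
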
diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala
new file mode 100644
index 0000000000..8266e5e495
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/** Provides a basic/boilerplate Iterator implementation. */
+private[spark] abstract class NextIterator[U] extends Iterator[U] {
+
+ private var gotNext = false
+ private var nextValue: U = _
+ private var closed = false
+ protected var finished = false
+
+ /**
+ * Method for subclasses to implement to provide the next element.
+ *
+ * If no next element is available, the subclass should set `finished`
+ * to `true` and may return any value (it will be ignored).
+ *
+ * This convention is required because `null` may be a valid value,
+ * and using `Option` seems like it might create unnecessary Some/None
+ * instances, given some iterators might be called in a tight loop.
+ *
+ * @return the next element, or an arbitrary (ignored) value once `finished` has been set to `true`
+ */
+ protected def getNext(): U
+
+ /**
+ * Method for subclasses to implement when all elements have been successfully
+ * iterated, and the iteration is done.
+ *
+ * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be
+ * called because it has no control over what happens when an exception
+ * happens in the user code that is calling hasNext/next.
+ *
+ * Ideally you should have another try/catch, as in HadoopRDD, that
+ * ensures any resources are closed should iteration fail.
+ */
+ protected def close()
+
+ /**
+ * Calls the subclass-defined close method, but only once.
+ *
+ * Usually calling `close` multiple times should be fine, but historically
+ * there have been issues with some InputFormats throwing exceptions.
+ */
+ def closeIfNeeded() {
+ if (!closed) {
+ close()
+ closed = true
+ }
+ }
+
+ override def hasNext: Boolean = {
+ if (!finished) {
+ if (!gotNext) {
+ nextValue = getNext()
+ if (finished) {
+ closeIfNeeded()
+ }
+ gotNext = true
+ }
+ }
+ !finished
+ }
+
+ override def next(): U = {
+ if (!hasNext) {
+ throw new NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ nextValue
+ }
+}
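
A hypothetical subclass (not part of the patch) showing the intended pattern: set `finished` when the source runs dry and release resources in close(). Assumes the class sits in an org.apache.spark package, since NextIterator is private[spark].

    import java.io.{BufferedReader, FileReader}

    private[spark] class LineIterator(path: String) extends NextIterator[String] {
      private val reader = new BufferedReader(new FileReader(path))

      override protected def getNext(): String = {
        val line = reader.readLine()
        if (line == null) {
          finished = true   // end of input; the returned value is ignored
        }
        line
      }

      override protected def close() {
        reader.close()
      }
    }
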
diff --git a/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala
new file mode 100644
index 0000000000..47e1b45004
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/RateLimitedOutputStream.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.annotation.tailrec
+
+import java.io.OutputStream
+import java.util.concurrent.TimeUnit._
+
+class RateLimitedOutputStream(out: OutputStream, bytesPerSec: Int) extends OutputStream {
+ val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+ val CHUNK_SIZE = 8192
+ var lastSyncTime = System.nanoTime
+ var bytesWrittenSinceSync: Long = 0
+
+ override def write(b: Int) {
+ waitToWrite(1)
+ out.write(b)
+ }
+
+ override def write(bytes: Array[Byte]) {
+ write(bytes, 0, bytes.length)
+ }
+
+ @tailrec
+ override final def write(bytes: Array[Byte], offset: Int, length: Int) {
+ val writeSize = math.min(length - offset, CHUNK_SIZE)
+ if (writeSize > 0) {
+ waitToWrite(writeSize)
+ out.write(bytes, offset, writeSize)
+ write(bytes, offset + writeSize, length)
+ }
+ }
+
+ override def flush() {
+ out.flush()
+ }
+
+ override def close() {
+ out.close()
+ }
+
+ @tailrec
+ private def waitToWrite(numBytes: Int) {
+ val now = System.nanoTime
+ val elapsedSecs = SECONDS.convert(math.max(now - lastSyncTime, 1), NANOSECONDS)
+ val rate = bytesWrittenSinceSync.toDouble / elapsedSecs
+ if (rate < bytesPerSec) {
+ // It's okay to write; just update some variables and return
+ bytesWrittenSinceSync += numBytes
+ if (now > lastSyncTime + SYNC_INTERVAL) {
+ // Sync interval has passed; let's resync
+ lastSyncTime = now
+ bytesWrittenSinceSync = numBytes
+ }
+ } else {
+ // Calculate how much time we should sleep to bring ourselves to the desired rate.
+ // Based on throttler in Kafka (https://github.com/kafka-dev/kafka/blob/master/core/src/main/scala/kafka/utils/Throttler.scala)
+ val sleepTime = MILLISECONDS.convert((bytesWrittenSinceSync / bytesPerSec - elapsedSecs), SECONDS)
+ if (sleepTime > 0) Thread.sleep(sleepTime)
+ waitToWrite(numBytes)
+ }
+ }
+}
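
A usage sketch (not part of the patch; path and rate are hypothetical): cap writes at roughly 1 MB/s.

    import java.io.FileOutputStream
    import org.apache.spark.util.RateLimitedOutputStream

    val out = new RateLimitedOutputStream(new FileOutputStream("/tmp/throttled.bin"), 1024 * 1024)
    out.write(new Array[Byte](4 * 1024 * 1024))   // blocks as needed to stay under the target rate
    out.close()
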
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
new file mode 100644
index 0000000000..f2b1ad7d0e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.nio.ByteBuffer
+import java.io.{IOException, ObjectOutputStream, EOFException, ObjectInputStream}
+import java.nio.channels.Channels
+
+/**
+ * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
+ * it easier to pass ByteBuffers in case class messages.
+ */
+private[spark]
+class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable {
+ def value = buffer
+
+ private def readObject(in: ObjectInputStream) {
+ val length = in.readInt()
+ buffer = ByteBuffer.allocate(length)
+ var amountRead = 0
+ val channel = Channels.newChannel(in)
+ while (amountRead < length) {
+ val ret = channel.read(buffer)
+ if (ret == -1) {
+ throw new EOFException("End of file before fully reading buffer")
+ }
+ amountRead += ret
+ }
+ buffer.rewind() // Allow us to read it later
+ }
+
+ private def writeObject(out: ObjectOutputStream) {
+ out.writeInt(buffer.limit())
+ if (Channels.newChannel(out).write(buffer) != buffer.limit()) {
+ throw new IOException("Could not fully write buffer to output stream")
+ }
+ buffer.rewind() // Allow us to write it again later
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
new file mode 100644
index 0000000000..020d5edba9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * A class for tracking the statistics of a set of numbers (count, mean and variance) in a
+ * numerically robust way. Includes support for merging two StatCounters. Based on
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance Welford and Chan's algorithms for running variance]].
+ *
+ * @constructor Initialize the StatCounter with the given values.
+ */
+class StatCounter(values: TraversableOnce[Double]) extends Serializable {
+ private var n: Long = 0 // Running count of our values
+ private var mu: Double = 0 // Running mean of our values
+ private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2)
+
+ merge(values)
+
+ /** Initialize the StatCounter with no values. */
+ def this() = this(Nil)
+
+ /** Add a value into this StatCounter, updating the internal statistics. */
+ def merge(value: Double): StatCounter = {
+ val delta = value - mu
+ n += 1
+ mu += delta / n
+ m2 += delta * (value - mu)
+ this
+ }
+
+ /** Add multiple values into this StatCounter, updating the internal statistics. */
+ def merge(values: TraversableOnce[Double]): StatCounter = {
+ values.foreach(v => merge(v))
+ this
+ }
+
+ /** Merge another StatCounter into this one, adding up the internal statistics. */
+ def merge(other: StatCounter): StatCounter = {
+ if (other == this) {
+ merge(other.copy()) // Avoid overwriting fields in a weird order
+ } else {
+ if (n == 0) {
+ mu = other.mu
+ m2 = other.m2
+ n = other.n
+ } else if (other.n != 0) {
+ val delta = other.mu - mu
+ if (other.n * 10 < n) {
+ mu = mu + (delta * other.n) / (n + other.n)
+ } else if (n * 10 < other.n) {
+ mu = other.mu - (delta * n) / (n + other.n)
+ } else {
+ mu = (mu * n + other.mu * other.n) / (n + other.n)
+ }
+ m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
+ n += other.n
+ }
+ this
+ }
+ }
+
+ /** Clone this StatCounter */
+ def copy(): StatCounter = {
+ val other = new StatCounter
+ other.n = n
+ other.mu = mu
+ other.m2 = m2
+ other
+ }
+
+ def count: Long = n
+
+ def mean: Double = mu
+
+ def sum: Double = n * mu
+
+ /** Return the variance of the values. */
+ def variance: Double = {
+ if (n == 0)
+ Double.NaN
+ else
+ m2 / n
+ }
+
+ /**
+ * Return the sample variance, which corrects for bias in estimating the variance by dividing
+ * by N-1 instead of N.
+ */
+ def sampleVariance: Double = {
+ if (n <= 1)
+ Double.NaN
+ else
+ m2 / (n - 1)
+ }
+
+ /** Return the standard deviation of the values. */
+ def stdev: Double = math.sqrt(variance)
+
+ /**
+ * Return the sample standard deviation of the values, which corrects for bias in estimating the
+ * variance by dividing by N-1 instead of N.
+ */
+ def sampleStdev: Double = math.sqrt(sampleVariance)
+
+ override def toString: String = {
+ "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev)
+ }
+}
+
+object StatCounter {
+ /** Build a StatCounter from a list of values. */
+ def apply(values: TraversableOnce[Double]) = new StatCounter(values)
+
+ /** Build a StatCounter from a list of values passed as variable-length arguments. */
+ def apply(values: Double*) = new StatCounter(values)
+}
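
A usage sketch (not part of the patch): build two summaries and merge them without revisiting the underlying values.

    import org.apache.spark.util.StatCounter

    val a = StatCounter(1.0, 2.0, 3.0, 4.0)
    val b = StatCounter(10.0, 20.0)
    println(a.mean)          // 2.5
    println(a.sampleStdev)   // about 1.29
    a.merge(b)               // combines counts, means and variances in one pass
    println(a.count)         // 6
    println(a.mean)          // about 6.67
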
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
new file mode 100644
index 0000000000..277de2f8a6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.ConcurrentHashMap
+import scala.collection.JavaConversions
+import scala.collection.mutable.Map
+import scala.collection.immutable
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.Logging
+
+/**
+ * This is a custom implementation of scala.collection.mutable.Map which stores the insertion
+ * time stamp along with each key-value pair. Key-value pairs that are older than a particular
+ * threshold time can then be removed using the clearOldValues method. This is intended to be a
+ * drop-in replacement for scala.collection.mutable.HashMap.
+ */
+class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
+ val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+
+ def get(key: A): Option[B] = {
+ val value = internalMap.get(key)
+ if (value != null) Some(value._1) else None
+ }
+
+ def iterator: Iterator[(A, B)] = {
+ val jIterator = internalMap.entrySet().iterator()
+ JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
+ }
+
+ override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+ val newMap = new TimeStampedHashMap[A, B1]
+ newMap.internalMap.putAll(this.internalMap)
+ newMap.internalMap.put(kv._1, (kv._2, currentTime))
+ newMap
+ }
+
+ override def - (key: A): Map[A, B] = {
+ val newMap = new TimeStampedHashMap[A, B]
+ newMap.internalMap.putAll(this.internalMap)
+ newMap.internalMap.remove(key)
+ newMap
+ }
+
+ override def += (kv: (A, B)): this.type = {
+ internalMap.put(kv._1, (kv._2, currentTime))
+ this
+ }
+
+ // Should we return previous value directly or as Option ?
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, (value, currentTime))
+ if (prev != null) Some(prev._1) else None
+ }
+
+
+ override def -= (key: A): this.type = {
+ internalMap.remove(key)
+ this
+ }
+
+ override def update(key: A, value: B) {
+ this += ((key, value))
+ }
+
+ override def apply(key: A): B = {
+ val value = internalMap.get(key)
+ if (value == null) throw new NoSuchElementException()
+ value._1
+ }
+
+ override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
+ JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ }
+
+ override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+
+ override def size: Int = internalMap.size
+
+ override def foreach[U](f: ((A, B)) => U) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ val kv = (entry.getKey, entry.getValue._1)
+ f(kv)
+ }
+ }
+
+ def toMap: immutable.Map[A, B] = iterator.toMap
+
+ /**
+ * Removes old key-value pairs that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ if (entry.getValue._2 < threshTime) {
+ logDebug("Removing key " + entry.getKey)
+ iterator.remove()
+ }
+ }
+ }
+
+ private def currentTime: Long = System.currentTimeMillis()
+
+}
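
A usage sketch (not part of the patch): entries record their insertion time, and clearOldValues drops everything older than the given threshold.

    import org.apache.spark.util.TimeStampedHashMap

    val cache = new TimeStampedHashMap[String, Int]()
    cache("stage_1") = 42
    Thread.sleep(100)
    cache("stage_2") = 43
    // Remove entries inserted more than 50 ms ago: only stage_2 survives.
    cache.clearOldValues(System.currentTimeMillis() - 50)
    println(cache.toMap)     // Map(stage_2 -> 43)
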
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala
new file mode 100644
index 0000000000..26983138ff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable.Set
+import scala.collection.JavaConversions
+import java.util.concurrent.ConcurrentHashMap
+
+
+class TimeStampedHashSet[A] extends Set[A] {
+ val internalMap = new ConcurrentHashMap[A, Long]()
+
+ def contains(key: A): Boolean = {
+ // ConcurrentHashMap.contains checks values, not keys, so use containsKey here.
+ internalMap.containsKey(key)
+ }
+
+ def iterator: Iterator[A] = {
+ val jIterator = internalMap.entrySet().iterator()
+ JavaConversions.asScalaIterator(jIterator).map(_.getKey)
+ }
+
+ override def + (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet += elem
+ newSet
+ }
+
+ override def - (elem: A): Set[A] = {
+ val newSet = new TimeStampedHashSet[A]
+ newSet ++= this
+ newSet -= elem
+ newSet
+ }
+
+ override def += (key: A): this.type = {
+ internalMap.put(key, currentTime)
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap.remove(key)
+ this
+ }
+
+ override def empty: Set[A] = new TimeStampedHashSet[A]()
+
+ override def size(): Int = internalMap.size()
+
+ override def foreach[U](f: (A) => U): Unit = {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ f(iterator.next.getKey)
+ }
+ }
+
+ /**
+ * Removes old values that have timestamp earlier than `threshTime`
+ */
+ def clearOldValues(threshTime: Long) {
+ val iterator = internalMap.entrySet().iterator()
+ while(iterator.hasNext) {
+ val entry = iterator.next()
+ if (entry.getValue < threshTime) {
+ iterator.remove()
+ }
+ }
+ }
+
+ private def currentTime: Long = System.currentTimeMillis()
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala
new file mode 100644
index 0000000000..fe710c58ac
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/Vector.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+class Vector(val elements: Array[Double]) extends Serializable {
+ def length = elements.length
+
+ def apply(index: Int) = elements(index)
+
+ def + (other: Vector): Vector = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ return Vector(length, i => this(i) + other(i))
+ }
+
+ def add(other: Vector) = this + other
+
+ def - (other: Vector): Vector = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ return Vector(length, i => this(i) - other(i))
+ }
+
+ def subtract(other: Vector) = this - other
+
+ def dot(other: Vector): Double = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += this(i) * other(i)
+ i += 1
+ }
+ return ans
+ }
+
+ /**
+ * Return (this + plus) dot other, without creating any intermediate storage.
+ * @param plus the vector added element-wise to this one before taking the dot product
+ * @param other the vector to take the dot product with
+ * @return the resulting dot product
+ */
+ def plusDot(plus: Vector, other: Vector): Double = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ if (length != plus.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += (this(i) + plus(i)) * other(i)
+ i += 1
+ }
+ return ans
+ }
+
+ def += (other: Vector): Vector = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var i = 0
+ while (i < length) {
+ elements(i) += other(i)
+ i += 1
+ }
+ this
+ }
+
+ def addInPlace(other: Vector) = this += other
+
+ def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
+
+ def multiply (d: Double) = this * d
+
+ def / (d: Double): Vector = this * (1 / d)
+
+ def divide (d: Double) = this / d
+
+ def unary_- = this * -1
+
+ def sum = elements.reduceLeft(_ + _)
+
+ def squaredDist(other: Vector): Double = {
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += (this(i) - other(i)) * (this(i) - other(i))
+ i += 1
+ }
+ return ans
+ }
+
+ def dist(other: Vector): Double = math.sqrt(squaredDist(other))
+
+ override def toString = elements.mkString("(", ", ", ")")
+}
+
+object Vector {
+ def apply(elements: Array[Double]) = new Vector(elements)
+
+ def apply(elements: Double*) = new Vector(elements.toArray)
+
+ def apply(length: Int, initializer: Int => Double): Vector = {
+ val elements: Array[Double] = Array.tabulate(length)(initializer)
+ return new Vector(elements)
+ }
+
+ def zeros(length: Int) = new Vector(new Array[Double](length))
+
+ def ones(length: Int) = Vector(length, _ => 1)
+
+ class Multiplier(num: Double) {
+ def * (vec: Vector) = vec * num
+ }
+
+ implicit def doubleToMultiplier(num: Double) = new Multiplier(num)
+
+ implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] {
+ def addInPlace(t1: Vector, t2: Vector) = t1 + t2
+
+ def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
+ }
+
+}
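
A usage sketch (not part of the patch) of the arithmetic operators and the implicit scalar multiplier:

    import org.apache.spark.util.Vector
    import org.apache.spark.util.Vector._   // brings doubleToMultiplier into scope for 2.0 * v

    val v = Vector(1.0, 2.0, 3.0)
    val w = Vector(4.0, 5.0, 6.0)
    println(v + w)        // (5.0, 7.0, 9.0)
    println(v dot w)      // 32.0
    println(2.0 * v)      // (2.0, 4.0, 6.0)
    println(v.dist(w))    // sqrt(27.0), about 5.196
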