aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--README.md13
-rw-r--r--assembly/pom.xml16
-rw-r--r--bagel/pom.xml12
-rw-r--r--bin/compute-classpath.cmd2
-rwxr-xr-xbin/compute-classpath.sh24
-rw-r--r--conf/metrics.properties.template8
-rw-r--r--core/pom.xml1593
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileClient.java2
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServer.java1
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java19
-rwxr-xr-xcore/src/main/java/org/apache/spark/network/netty/PathResolver.java11
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala17
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala210
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala401
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/TaskState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala73
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala104
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala43
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java1
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java10
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java5
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function.java11
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function2.java11
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function3.java (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala)29
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java15
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java18
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala163
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala1060
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala410
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala247
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala601
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/Client.scala50
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala78
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala48
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala64
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala31
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala82
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/package.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala127
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala133
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala57
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala494
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala676
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala47
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala64
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala67
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala52
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockInfo.scala81
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala565
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala142
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala151
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala263
-rw-r--r--core/src/main/scala/org/apache/spark/storage/FileSegment.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala196
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageLevel.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala103
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala90
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala139
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala96
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala35
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/AkkaUtils.scala85
-rw-r--r--core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala93
-rw-r--r--core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala68
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala94
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/BitSet.scala103
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala153
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala279
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala128
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala69
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala32
-rw-r--r--core/src/test/scala/org/apache/spark/BroadcastSuite.scala52
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/DriverSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/FileServerSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java68
-rw-r--r--core/src/test/scala/org/apache/spark/JobCancellationSuite.scala34
-rw-r--r--core/src/test/scala/org/apache/spark/LocalSparkContext.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/PartitioningSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala140
-rw-r--r--core/src/test/scala/org/apache/spark/UnpersistSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala (renamed from core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala)35
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala86
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala28
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala51
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala132
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala111
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala73
-rw-r--r--core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala72
-rw-r--r--core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala76
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala73
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala177
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala180
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala119
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala117
-rw-r--r--docker/spark-test/README.md7
-rw-r--r--docs/_config.yml2
-rwxr-xr-xdocs/_layouts/global.html8
-rw-r--r--docs/_plugins/copy_api_dirs.rb2
-rw-r--r--docs/bagel-programming-guide.md2
-rw-r--r--docs/building-with-maven.md6
-rw-r--r--docs/cluster-overview.md16
-rw-r--r--docs/configuration.md69
-rw-r--r--docs/ec2-scripts.md2
-rw-r--r--docs/hadoop-third-party-distributions.md7
-rw-r--r--docs/index.md8
-rw-r--r--docs/job-scheduling.md2
-rw-r--r--docs/monitoring.md1
-rw-r--r--docs/python-programming-guide.md11
-rw-r--r--docs/running-on-yarn.md46
-rw-r--r--docs/scala-programming-guide.md8
-rw-r--r--docs/spark-standalone.md4
-rw-r--r--docs/streaming-programming-guide.md12
-rw-r--r--docs/tuning.md5
-rwxr-xr-xec2/spark_ec2.py72
-rw-r--r--examples/pom.xml60
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java2
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaPageRank.java3
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaWordCount.java2
-rw-r--r--examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java1
-rw-r--r--examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java98
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala21
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/LocalALS.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala15
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala3
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkPi.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/SparkTC.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala9
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala28
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala107
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala8
-rw-r--r--mllib/pom.xml12
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala2
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java2
-rw-r--r--new-yarn/pom.xml161
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala446
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala94
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala519
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala149
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala228
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala222
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala209
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala687
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala43
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala47
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala109
-rw-r--r--new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala55
-rw-r--r--new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala220
-rw-r--r--pom.xml256
-rw-r--r--project/SparkBuild.scala172
-rw-r--r--project/plugins.sbt2
-rwxr-xr-xpyspark10
-rw-r--r--pyspark2.cmd2
-rw-r--r--python/epydoc.conf2
-rw-r--r--python/pyspark/accumulators.py19
-rw-r--r--python/pyspark/context.py116
-rw-r--r--python/pyspark/rdd.py114
-rw-r--r--python/pyspark/serializers.py301
-rw-r--r--python/pyspark/tests.py18
-rw-r--r--python/pyspark/worker.py44
-rwxr-xr-xpython/run-tests1
-rwxr-xr-xpython/test_support/userlibrary.py17
-rw-r--r--repl-bin/pom.xml8
-rwxr-xr-xrepl-bin/src/deb/bin/run2
-rwxr-xr-xrepl-bin/src/deb/bin/spark-executor2
-rwxr-xr-xrepl-bin/src/deb/bin/spark-shell2
-rw-r--r--repl/lib/scala-jline.jarbin158463 -> 0 bytes
-rw-r--r--repl/pom.xml18
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/Main.scala8
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala109
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala949
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala143
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala1674
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala63
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkImports.scala108
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala206
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala65
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala109
-rw-r--r--repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala213
-rwxr-xr-xrun-example12
-rw-r--r--run-example2.cmd2
-rwxr-xr-xsbt/sbt21
-rwxr-xr-xspark-class44
-rw-r--r--spark-class2.cmd7
-rwxr-xr-xspark-shell19
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jarbin1358063 -> 0 bytes
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom9
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha11
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml12
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md51
-rw-r--r--streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha11
-rw-r--r--streaming/pom.xml74
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala8
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStream.scala97
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala6
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala88
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala197
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala112
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala14
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala128
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala233
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala201
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala58
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala12
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala6
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala68
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala110
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala18
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala9
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala3
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala6
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala21
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala35
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala13
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala55
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/Job.scala)24
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala)52
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala108
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala68
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala)5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala75
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala81
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala45
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java513
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala54
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala153
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala66
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala13
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala107
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala71
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala118
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala14
-rw-r--r--tools/pom.xml12
-rw-r--r--tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala4
-rw-r--r--yarn/pom.xml58
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala199
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala435
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala41
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala228
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala243
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala145
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala358
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala5
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala47
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala109
-rw-r--r--yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala220
374 files changed, 19051 insertions, 8424 deletions
diff --git a/.gitignore b/.gitignore
index e1f64a1133..b3c4363af0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,4 @@ derby.log
dist/
spark-*-bin.tar.gz
unit-tests.log
+lib/
diff --git a/README.md b/README.md
index 28ad1e4604..1550a8b551 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,8 @@ This README file only contains basic setup instructions.
## Building
-Spark requires Scala 2.9.3 (Scala 2.10 is not yet supported). The project is
-built using Simple Build Tool (SBT), which is packaged with it. To build
-Spark and its example programs, run:
+Spark requires Scala 2.10. The project is built using Simple Build Tool (SBT),
+which is packaged with it. To build Spark and its example programs, run:
sbt/sbt assembly
@@ -55,7 +54,7 @@ versions without YARN, use:
# Cloudera CDH 4.2.0 with MapReduce v1
$ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly
-For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions
+For Apache Hadoop 2.2.X, 2.1.X, 2.0.X, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions
with YARN, also set `SPARK_YARN=true`:
# Apache Hadoop 2.0.5-alpha
@@ -64,12 +63,12 @@ with YARN, also set `SPARK_YARN=true`:
# Cloudera CDH 4.2.0 with MapReduce v2
$ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_YARN=true sbt/sbt assembly
-For convenience, these variables may also be set through the `conf/spark-env.sh` file
-described below.
+ # Apache Hadoop 2.2.X and newer
+ $ SPARK_HADOOP_VERSION=2.2.0 SPARK_YARN=true sbt/sbt assembly
When developing a Spark application, specify the Hadoop version by adding the
"hadoop-client" artifact to your project's dependencies. For example, if you're
-using Hadoop 1.0.1 and build your application using SBT, add this entry to
+using Hadoop 1.2.1 and build your application using SBT, add this entry to
`libraryDependencies`:
"org.apache.hadoop" % "hadoop-client" % "1.2.1"
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 09df8c1fd7..fc2adc1fbb 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-assembly_2.9.3</artifactId>
+ <artifactId>spark-assembly_2.10</artifactId>
<name>Spark Project Assembly</name>
<url>http://spark.incubator.apache.org/</url>
@@ -41,27 +41,27 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel_2.9.3</artifactId>
+ <artifactId>spark-bagel_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib_2.9.3</artifactId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl_2.9.3</artifactId>
+ <artifactId>spark-repl_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming_2.9.3</artifactId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -79,7 +79,7 @@
<artifactId>maven-shade-plugin</artifactId>
<configuration>
<shadedArtifactAttached>false</shadedArtifactAttached>
- <outputFile>${project.build.directory}/scala-${scala.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar</outputFile>
+ <outputFile>${project.build.directory}/scala-${scala.binary.version}/${project.artifactId}-${project.version}-hadoop${hadoop.version}.jar</outputFile>
<artifactSet>
<includes>
<include>*:*</include>
@@ -128,7 +128,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-yarn_2.9.3</artifactId>
+ <artifactId>spark-yarn_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 0e552c880f..cb8e79f225 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel_2.9.3</artifactId>
+ <artifactId>spark-bagel_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Bagel</name>
<url>http://spark.incubator.apache.org/</url>
@@ -34,7 +34,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -43,18 +43,18 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index cf38188c4b..9e3e10ecaa 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -20,7 +20,7 @@ rem
rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
rem script and the ExecutorRunner in standalone cluster mode.
-set SCALA_VERSION=2.9.3
+set SCALA_VERSION=2.10
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0..\
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index c7819d4932..40555089fc 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -20,7 +20,7 @@
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
# script and the ExecutorRunner in standalone cluster mode.
-SCALA_VERSION=2.9.3
+SCALA_VERSION=2.10
# Figure out where Spark is installed
FWDIR="$(cd `dirname $0`/..; pwd)"
@@ -32,12 +32,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
-if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+
+# First check if we have a dependencies jar. If so, include binary classes with the deps jar
+if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+
+ DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+ CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
- ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ # Else use spark-assembly jar from either RELEASE or assembly directory
+ if [ -f "$FWDIR/RELEASE" ]; then
+ ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+ else
+ ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ fi
+ CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
-CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index ae10f615d1..1c3d94e1b0 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -80,6 +80,14 @@
# /metrics/aplications/json # App information
# /metrics/master/json # Master information
+# org.apache.spark.metrics.sink.GraphiteSink
+# Name: Default: Description:
+# host NONE Hostname of Graphite server
+# port NONE Port of Graphite server
+# period 10 Poll period
+# unit seconds Units of poll period
+# prefix EMPTY STRING Prefix to prepend to metric name
+
## Examples
# Enable JmxSink for all instances by class name
#*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink
diff --git a/core/pom.xml b/core/pom.xml
index e53875c72d..f0248bc539 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -1,242 +1,1351 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
- ~ Licensed to the Apache Software Foundation (ASF) under one or more
- ~ contributor license agreements. See the NOTICE file distributed with
- ~ this work for additional information regarding copyright ownership.
- ~ The ASF licenses this file to You under the Apache License, Version 2.0
- ~ (the "License"); you may not use this file except in compliance with
- ~ the License. You may obtain a copy of the License at
- ~
- ~ http://www.apache.org/licenses/LICENSE-2.0
- ~
- ~ Unless required by applicable law or agreed to in writing, software
- ~ distributed under the License is distributed on an "AS IS" BASIS,
- ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- ~ See the License for the specific language governing permissions and
- ~ limitations under the License.
- -->
-
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <modelVersion>4.0.0</modelVersion>
- <parent>
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-parent</artifactId>
- <version>0.9.0-incubating-SNAPSHOT</version>
- <relativePath>../pom.xml</relativePath>
- </parent>
-
- <groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
- <packaging>jar</packaging>
- <name>Spark Project Core</name>
- <url>http://spark.incubator.apache.org/</url>
-
- <dependencies>
- <dependency>
- <groupId>org.apache.hadoop</groupId>
- <artifactId>hadoop-client</artifactId>
- </dependency>
- <dependency>
- <groupId>net.java.dev.jets3t</groupId>
- <artifactId>jets3t</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.avro</groupId>
- <artifactId>avro</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.avro</groupId>
- <artifactId>avro-ipc</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.zookeeper</groupId>
- <artifactId>zookeeper</artifactId>
- </dependency>
- <dependency>
- <groupId>org.eclipse.jetty</groupId>
- <artifactId>jetty-server</artifactId>
- </dependency>
- <dependency>
- <groupId>com.google.guava</groupId>
- <artifactId>guava</artifactId>
- </dependency>
- <dependency>
- <groupId>com.google.code.findbugs</groupId>
- <artifactId>jsr305</artifactId>
- </dependency>
- <dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </dependency>
- <dependency>
- <groupId>com.ning</groupId>
- <artifactId>compress-lzf</artifactId>
- </dependency>
- <dependency>
- <groupId>org.xerial.snappy</groupId>
- <artifactId>snappy-java</artifactId>
- </dependency>
- <dependency>
- <groupId>org.ow2.asm</groupId>
- <artifactId>asm</artifactId>
- </dependency>
- <dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
- <groupId>com.clearspring.analytics</groupId>
- <artifactId>stream</artifactId>
- </dependency>
- <dependency>
- <groupId>com.twitter</groupId>
- <artifactId>chill_2.9.3</artifactId>
- <version>0.3.1</version>
- </dependency>
- <dependency>
- <groupId>com.twitter</groupId>
- <artifactId>chill-java</artifactId>
- <version>0.3.1</version>
- </dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-actor</artifactId>
- </dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-remote</artifactId>
- </dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-slf4j</artifactId>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scalap</artifactId>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scala-library</artifactId>
- </dependency>
- <dependency>
- <groupId>net.liftweb</groupId>
- <artifactId>lift-json_2.9.2</artifactId>
- </dependency>
- <dependency>
- <groupId>it.unimi.dsi</groupId>
- <artifactId>fastutil</artifactId>
- </dependency>
- <dependency>
- <groupId>colt</groupId>
- <artifactId>colt</artifactId>
- </dependency>
- <dependency>
- <groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_2.9.2</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.mesos</groupId>
- <artifactId>mesos</artifactId>
- </dependency>
- <dependency>
- <groupId>io.netty</groupId>
- <artifactId>netty-all</artifactId>
- </dependency>
- <dependency>
- <groupId>log4j</groupId>
- <artifactId>log4j</artifactId>
- </dependency>
- <dependency>
- <groupId>com.codahale.metrics</groupId>
- <artifactId>metrics-core</artifactId>
- </dependency>
- <dependency>
- <groupId>com.codahale.metrics</groupId>
- <artifactId>metrics-jvm</artifactId>
- </dependency>
- <dependency>
- <groupId>com.codahale.metrics</groupId>
- <artifactId>metrics-json</artifactId>
- </dependency>
- <dependency>
- <groupId>com.codahale.metrics</groupId>
- <artifactId>metrics-ganglia</artifactId>
- </dependency>
- <dependency>
- <groupId>org.apache.derby</groupId>
- <artifactId>derby</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.easymock</groupId>
- <artifactId>easymock</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>com.novocode</groupId>
- <artifactId>junit-interface</artifactId>
- <scope>test</scope>
- </dependency>
- <dependency>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-log4j12</artifactId>
- <scope>test</scope>
- </dependency>
- </dependencies>
- <build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
- <plugins>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-antrun-plugin</artifactId>
- <executions>
- <execution>
- <phase>test</phase>
- <goals>
- <goal>run</goal>
- </goals>
- <configuration>
- <exportAntProperties>true</exportAntProperties>
- <tasks>
- <property name="spark.classpath" refid="maven.test.classpath" />
- <property environment="env" />
- <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
- <condition>
- <not>
- <or>
- <isset property="env.SCALA_HOME" />
- <isset property="env.SCALA_LIBRARY_PATH" />
- </or>
- </not>
- </condition>
- </fail>
- </tasks>
- </configuration>
- </execution>
- </executions>
- </plugin>
- <plugin>
- <groupId>org.scalatest</groupId>
- <artifactId>scalatest-maven-plugin</artifactId>
- <configuration>
- <environmentVariables>
- <SPARK_HOME>${basedir}/..</SPARK_HOME>
- <SPARK_TESTING>1</SPARK_TESTING>
- <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
- </environmentVariables>
- </configuration>
- </plugin>
- </plugins>
- </build>
-</project>
+
+
+
+<!DOCTYPE html>
+<html>
+ <head prefix="og: http://ogp.me/ns# fb: http://ogp.me/ns/fb# githubog: http://ogp.me/ns/fb/githubog#">
+ <meta charset='utf-8'>
+ <meta http-equiv="X-UA-Compatible" content="IE=edge">
+ <title>incubator-spark/core/pom.xml at master · apache/incubator-spark · GitHub</title>
+ <link rel="search" type="application/opensearchdescription+xml" href="/opensearch.xml" title="GitHub" />
+ <link rel="fluid-icon" href="https://github.com/fluidicon.png" title="GitHub" />
+ <link rel="apple-touch-icon" sizes="57x57" href="/apple-touch-icon-114.png" />
+ <link rel="apple-touch-icon" sizes="114x114" href="/apple-touch-icon-114.png" />
+ <link rel="apple-touch-icon" sizes="72x72" href="/apple-touch-icon-144.png" />
+ <link rel="apple-touch-icon" sizes="144x144" href="/apple-touch-icon-144.png" />
+ <link rel="logo" type="image/svg" href="https://github-media-downloads.s3.amazonaws.com/github-logo.svg" />
+ <meta property="og:image" content="https://github.global.ssl.fastly.net/images/modules/logos_page/Octocat.png">
+ <meta name="hostname" content="github-fe126-cp1-prd.iad.github.net">
+ <meta name="ruby" content="ruby 1.9.3p194-tcs-github-tcmalloc (e1c0c3f392) [x86_64-linux]">
+ <link rel="assets" href="https://github.global.ssl.fastly.net/">
+ <link rel="conduit-xhr" href="https://ghconduit.com:25035/">
+ <link rel="xhr-socket" href="/_sockets" />
+
+
+
+ <meta name="msapplication-TileImage" content="/windows-tile.png" />
+ <meta name="msapplication-TileColor" content="#ffffff" />
+ <meta name="selected-link" value="repo_source" data-pjax-transient />
+ <meta content="collector.githubapp.com" name="octolytics-host" /><meta content="collector-cdn.github.com" name="octolytics-script-host" /><meta content="github" name="octolytics-app-id" /><meta content="43A45EED:3C36:63F899:52C1FB77" name="octolytics-dimension-request_id" />
+
+
+
+
+ <link rel="icon" type="image/x-icon" href="/favicon.ico" />
+
+ <meta content="authenticity_token" name="csrf-param" />
+<meta content="XL0k65gLgVyHEQhufVTLBNEWwGhZf67623b0mJZcY2A=" name="csrf-token" />
+
+ <link href="https://github.global.ssl.fastly.net/assets/github-3944f96c1c19f752fe766b332fb7716555c8296e.css" media="all" rel="stylesheet" type="text/css" />
+ <link href="https://github.global.ssl.fastly.net/assets/github2-b64d0ad5fa62a30a166145ae08b8c0a6d2f7dea7.css" media="all" rel="stylesheet" type="text/css" />
+
+
+
+
+ <script src="https://github.global.ssl.fastly.net/assets/frameworks-29a3fb0547e33bd8d4530bbad9bae3ef00d83293.js" type="text/javascript"></script>
+ <script src="https://github.global.ssl.fastly.net/assets/github-3fbe2841590c916eeba07af3fc626dd593d2f5ba.js" type="text/javascript"></script>
+
+ <meta http-equiv="x-pjax-version" content="8983adfc0294e4e53e92b27093d9e927">
+
+ <link data-pjax-transient rel='permalink' href='/apache/incubator-spark/blob/50e3b8ec4c8150f1cfc6b92f8871f520adf2cfda/core/pom.xml'>
+ <meta property="og:title" content="incubator-spark"/>
+ <meta property="og:type" content="githubog:gitrepository"/>
+ <meta property="og:url" content="https://github.com/apache/incubator-spark"/>
+ <meta property="og:image" content="https://github.global.ssl.fastly.net/images/gravatars/gravatar-user-420.png"/>
+ <meta property="og:site_name" content="GitHub"/>
+ <meta property="og:description" content="incubator-spark - Mirror of Apache Spark"/>
+
+ <meta name="description" content="incubator-spark - Mirror of Apache Spark" />
+
+ <meta content="47359" name="octolytics-dimension-user_id" /><meta content="apache" name="octolytics-dimension-user_login" /><meta content="10960835" name="octolytics-dimension-repository_id" /><meta content="apache/incubator-spark" name="octolytics-dimension-repository_nwo" /><meta content="true" name="octolytics-dimension-repository_public" /><meta content="false" name="octolytics-dimension-repository_is_fork" /><meta content="10960835" name="octolytics-dimension-repository_network_root_id" /><meta content="apache/incubator-spark" name="octolytics-dimension-repository_network_root_nwo" />
+ <link href="https://github.com/apache/incubator-spark/commits/master.atom" rel="alternate" title="Recent Commits to incubator-spark:master" type="application/atom+xml" />
+
+ </head>
+
+
+ <body class="logged_out env-production vis-public mirror page-blob">
+ <div class="wrapper">
+
+
+
+
+
+
+
+ <div class="header header-logged-out">
+ <div class="container clearfix">
+
+ <a class="header-logo-wordmark" href="https://github.com/">
+ <span class="mega-octicon octicon-logo-github"></span>
+ </a>
+
+ <div class="header-actions">
+ <a class="button primary" href="/join">Sign up</a>
+ <a class="button signin" href="/login?return_to=%2Fapache%2Fincubator-spark%2Fblob%2Fmaster%2Fcore%2Fpom.xml">Sign in</a>
+ </div>
+
+ <div class="command-bar js-command-bar in-repository">
+
+ <ul class="top-nav">
+ <li class="explore"><a href="/explore">Explore</a></li>
+ <li class="features"><a href="/features">Features</a></li>
+ <li class="enterprise"><a href="https://enterprise.github.com/">Enterprise</a></li>
+ <li class="blog"><a href="/blog">Blog</a></li>
+ </ul>
+ <form accept-charset="UTF-8" action="/search" class="command-bar-form" id="top_search_form" method="get">
+
+<input type="text" data-hotkey="/ s" name="q" id="js-command-bar-field" placeholder="Search or type a command" tabindex="1" autocapitalize="off"
+
+
+ data-repo="apache/incubator-spark"
+ data-branch="master"
+ data-sha="ede8c631c6d7a940c1ab1629574ec1003eb0861e"
+ >
+
+ <input type="hidden" name="nwo" value="apache/incubator-spark" />
+
+ <div class="select-menu js-menu-container js-select-menu search-context-select-menu">
+ <span class="minibutton select-menu-button js-menu-target">
+ <span class="js-select-button">This repository</span>
+ </span>
+
+ <div class="select-menu-modal-holder js-menu-content js-navigation-container">
+ <div class="select-menu-modal">
+
+ <div class="select-menu-item js-navigation-item js-this-repository-navigation-item selected">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <input type="radio" class="js-search-this-repository" name="search_target" value="repository" checked="checked" />
+ <div class="select-menu-item-text js-select-button-text">This repository</div>
+ </div> <!-- /.select-menu-item -->
+
+ <div class="select-menu-item js-navigation-item js-all-repositories-navigation-item">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <input type="radio" name="search_target" value="global" />
+ <div class="select-menu-item-text js-select-button-text">All repositories</div>
+ </div> <!-- /.select-menu-item -->
+
+ </div>
+ </div>
+ </div>
+
+ <span class="octicon help tooltipped downwards" title="Show command bar help">
+ <span class="octicon octicon-question"></span>
+ </span>
+
+
+ <input type="hidden" name="ref" value="cmdform">
+
+</form>
+ </div>
+
+ </div>
+</div>
+
+
+
+
+
+ <div class="site" itemscope itemtype="http://schema.org/WebPage">
+
+ <div class="pagehead repohead instapaper_ignore readability-menu">
+ <div class="container">
+
+
+<ul class="pagehead-actions">
+
+
+ <li>
+ <a href="/login?return_to=%2Fapache%2Fincubator-spark"
+ class="minibutton with-count js-toggler-target star-button tooltipped upwards"
+ title="You must be signed in to use this feature" rel="nofollow">
+ <span class="octicon octicon-star"></span>Star
+ </a>
+
+ <a class="social-count js-social-count" href="/apache/incubator-spark/stargazers">
+ 322
+ </a>
+
+ </li>
+
+ <li>
+ <a href="/login?return_to=%2Fapache%2Fincubator-spark"
+ class="minibutton with-count js-toggler-target fork-button tooltipped upwards"
+ title="You must be signed in to fork a repository" rel="nofollow">
+ <span class="octicon octicon-git-branch"></span>Fork
+ </a>
+ <a href="/apache/incubator-spark/network" class="social-count">
+ 273
+ </a>
+ </li>
+</ul>
+
+ <h1 itemscope itemtype="http://data-vocabulary.org/Breadcrumb" class="entry-title public">
+ <span class="repo-label"><span>public</span></span>
+ <span class="mega-octicon octicon-repo"></span>
+ <span class="author">
+ <a href="/apache" class="url fn" itemprop="url" rel="author"><span itemprop="title">apache</span></a>
+ </span>
+ <span class="repohead-name-divider">/</span>
+ <strong><a href="/apache/incubator-spark" class="js-current-repository js-repo-home-link">incubator-spark</a></strong>
+
+ <span class="page-context-loader">
+ <img alt="Octocat-spinner-32" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+ </span>
+
+ <span class="mirror-flag">
+ <span class="text">mirrored from <a href="git://git.apache.org/incubator-spark.git">git://git.apache.org/incubator-spark.git</a></span>
+ </span>
+ </h1>
+ </div><!-- /.container -->
+ </div><!-- /.repohead -->
+
+ <div class="container">
+
+
+ <div class="repository-with-sidebar repo-container ">
+
+ <div class="repository-sidebar">
+
+
+<div class="sunken-menu vertical-right repo-nav js-repo-nav js-repository-container-pjax js-octicon-loaders">
+ <div class="sunken-menu-contents">
+ <ul class="sunken-menu-group">
+ <li class="tooltipped leftwards" title="Code">
+ <a href="/apache/incubator-spark" aria-label="Code" class="selected js-selected-navigation-item sunken-menu-item" data-gotokey="c" data-pjax="true" data-selected-links="repo_source repo_downloads repo_commits repo_tags repo_branches /apache/incubator-spark">
+ <span class="octicon octicon-code"></span> <span class="full-word">Code</span>
+ <img alt="Octocat-spinner-32" class="mini-loader" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+</a> </li>
+
+
+ <li class="tooltipped leftwards" title="Pull Requests">
+ <a href="/apache/incubator-spark/pulls" aria-label="Pull Requests" class="js-selected-navigation-item sunken-menu-item js-disable-pjax" data-gotokey="p" data-selected-links="repo_pulls /apache/incubator-spark/pulls">
+ <span class="octicon octicon-git-pull-request"></span> <span class="full-word">Pull Requests</span>
+ <span class='counter'>45</span>
+ <img alt="Octocat-spinner-32" class="mini-loader" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+</a> </li>
+
+
+ </ul>
+ <div class="sunken-menu-separator"></div>
+ <ul class="sunken-menu-group">
+
+ <li class="tooltipped leftwards" title="Pulse">
+ <a href="/apache/incubator-spark/pulse" aria-label="Pulse" class="js-selected-navigation-item sunken-menu-item" data-pjax="true" data-selected-links="pulse /apache/incubator-spark/pulse">
+ <span class="octicon octicon-pulse"></span> <span class="full-word">Pulse</span>
+ <img alt="Octocat-spinner-32" class="mini-loader" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+</a> </li>
+
+ <li class="tooltipped leftwards" title="Graphs">
+ <a href="/apache/incubator-spark/graphs" aria-label="Graphs" class="js-selected-navigation-item sunken-menu-item" data-pjax="true" data-selected-links="repo_graphs repo_contributors /apache/incubator-spark/graphs">
+ <span class="octicon octicon-graph"></span> <span class="full-word">Graphs</span>
+ <img alt="Octocat-spinner-32" class="mini-loader" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+</a> </li>
+
+ <li class="tooltipped leftwards" title="Network">
+ <a href="/apache/incubator-spark/network" aria-label="Network" class="js-selected-navigation-item sunken-menu-item js-disable-pjax" data-selected-links="repo_network /apache/incubator-spark/network">
+ <span class="octicon octicon-git-branch"></span> <span class="full-word">Network</span>
+ <img alt="Octocat-spinner-32" class="mini-loader" height="16" src="https://github.global.ssl.fastly.net/images/spinners/octocat-spinner-32.gif" width="16" />
+</a> </li>
+ </ul>
+
+
+ </div>
+</div>
+
+ <div class="only-with-full-nav">
+
+
+
+
+<div class="clone-url open"
+ data-protocol-type="http"
+ data-url="/users/set_protocol?protocol_selector=http&amp;protocol_type=clone">
+ <h3><strong>HTTPS</strong> clone URL</h3>
+ <div class="clone-url-box">
+ <input type="text" class="clone js-url-field"
+ value="https://github.com/apache/incubator-spark.git" readonly="readonly">
+
+ <span class="js-zeroclipboard url-box-clippy minibutton zeroclipboard-button" data-clipboard-text="https://github.com/apache/incubator-spark.git" data-copied-hint="copied!" title="copy to clipboard"><span class="octicon octicon-clippy"></span></span>
+ </div>
+</div>
+
+
+
+<div class="clone-url "
+ data-protocol-type="subversion"
+ data-url="/users/set_protocol?protocol_selector=subversion&amp;protocol_type=clone">
+ <h3><strong>Subversion</strong> checkout URL</h3>
+ <div class="clone-url-box">
+ <input type="text" class="clone js-url-field"
+ value="https://github.com/apache/incubator-spark" readonly="readonly">
+
+ <span class="js-zeroclipboard url-box-clippy minibutton zeroclipboard-button" data-clipboard-text="https://github.com/apache/incubator-spark" data-copied-hint="copied!" title="copy to clipboard"><span class="octicon octicon-clippy"></span></span>
+ </div>
+</div>
+
+
+<p class="clone-options">You can clone with
+ <a href="#" class="js-clone-selector" data-protocol="http">HTTPS</a>,
+ or <a href="#" class="js-clone-selector" data-protocol="subversion">Subversion</a>.
+ <span class="octicon help tooltipped upwards" title="Get help on which URL is right for you.">
+ <a href="https://help.github.com/articles/which-remote-url-should-i-use">
+ <span class="octicon octicon-question"></span>
+ </a>
+ </span>
+</p>
+
+
+
+ <a href="/apache/incubator-spark/archive/master.zip"
+ class="minibutton sidebar-button"
+ title="Download this repository as a zip file"
+ rel="nofollow">
+ <span class="octicon octicon-cloud-download"></span>
+ Download ZIP
+ </a>
+ </div>
+ </div><!-- /.repository-sidebar -->
+
+ <div id="js-repo-pjax-container" class="repository-content context-loader-container" data-pjax-container>
+
+
+
+<!-- blob contrib key: blob_contributors:v21:ca1fa3336589a56eafe4ec1105b40975 -->
+
+<p title="This is a placeholder element" class="js-history-link-replace hidden"></p>
+
+<a href="/apache/incubator-spark/find/master" data-pjax data-hotkey="t" class="js-show-file-finder" style="display:none">Show File Finder</a>
+
+<div class="file-navigation">
+
+
+<div class="select-menu js-menu-container js-select-menu" >
+ <span class="minibutton select-menu-button js-menu-target" data-hotkey="w"
+ data-master-branch="master"
+ data-ref="master"
+ role="button" aria-label="Switch branches or tags" tabindex="0">
+ <span class="octicon octicon-git-branch"></span>
+ <i>branch:</i>
+ <span class="js-select-button">master</span>
+ </span>
+
+ <div class="select-menu-modal-holder js-menu-content js-navigation-container" data-pjax>
+
+ <div class="select-menu-modal">
+ <div class="select-menu-header">
+ <span class="select-menu-title">Switch branches/tags</span>
+ <span class="octicon octicon-remove-close js-menu-close"></span>
+ </div> <!-- /.select-menu-header -->
+
+ <div class="select-menu-filters">
+ <div class="select-menu-text-filter">
+ <input type="text" aria-label="Filter branches/tags" id="context-commitish-filter-field" class="js-filterable-field js-navigation-enable" placeholder="Filter branches/tags">
+ </div>
+ <div class="select-menu-tabs">
+ <ul>
+ <li class="select-menu-tab">
+ <a href="#" data-tab-filter="branches" class="js-select-menu-tab">Branches</a>
+ </li>
+ <li class="select-menu-tab">
+ <a href="#" data-tab-filter="tags" class="js-select-menu-tab">Tags</a>
+ </li>
+ </ul>
+ </div><!-- /.select-menu-tabs -->
+ </div><!-- /.select-menu-filters -->
+
+ <div class="select-menu-list select-menu-tab-bucket js-select-menu-tab-bucket" data-tab-filter="branches">
+
+ <div data-filterable-for="context-commitish-filter-field" data-filterable-type="substring">
+
+
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/akka-actors/core/pom.xml"
+ data-name="akka-actors"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="akka-actors">akka-actors</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/arthur/core/pom.xml"
+ data-name="arthur"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="arthur">arthur</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/branch-0.5/core/pom.xml"
+ data-name="branch-0.5"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="branch-0.5">branch-0.5</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/branch-0.6/core/pom.xml"
+ data-name="branch-0.6"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="branch-0.6">branch-0.6</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/branch-0.7/core/pom.xml"
+ data-name="branch-0.7"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="branch-0.7">branch-0.7</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/branch-0.8/core/pom.xml"
+ data-name="branch-0.8"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="branch-0.8">branch-0.8</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/charles-newhadoop/core/pom.xml"
+ data-name="charles-newhadoop"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="charles-newhadoop">charles-newhadoop</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/dev/core/pom.xml"
+ data-name="dev"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="dev">dev</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/formatting/core/pom.xml"
+ data-name="formatting"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="formatting">formatting</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/hive/core/pom.xml"
+ data-name="hive"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="hive">hive</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/java-api/core/pom.xml"
+ data-name="java-api"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="java-api">java-api</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item selected">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/master/core/pom.xml"
+ data-name="master"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="master">master</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mesos-0.9/core/pom.xml"
+ data-name="mesos-0.9"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mesos-0.9">mesos-0.9</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-bt/core/pom.xml"
+ data-name="mos-bt"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-bt">mos-bt</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-bt-dev/core/pom.xml"
+ data-name="mos-bt-dev"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-bt-dev">mos-bt-dev</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-bt-topo/core/pom.xml"
+ data-name="mos-bt-topo"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-bt-topo">mos-bt-topo</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-shuffle/core/pom.xml"
+ data-name="mos-shuffle"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-shuffle">mos-shuffle</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-shuffle-supertracked/core/pom.xml"
+ data-name="mos-shuffle-supertracked"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-shuffle-supertracked">mos-shuffle-supertracked</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/mos-shuffle-tracked/core/pom.xml"
+ data-name="mos-shuffle-tracked"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="mos-shuffle-tracked">mos-shuffle-tracked</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/multi-tracker/core/pom.xml"
+ data-name="multi-tracker"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="multi-tracker">multi-tracker</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/new-rdds-protobuf/core/pom.xml"
+ data-name="new-rdds-protobuf"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="new-rdds-protobuf">new-rdds-protobuf</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/object-file-fix/core/pom.xml"
+ data-name="object-file-fix"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="object-file-fix">object-file-fix</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/old-mesos/core/pom.xml"
+ data-name="old-mesos"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="old-mesos">old-mesos</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/old-rdds/core/pom.xml"
+ data-name="old-rdds"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="old-rdds">old-rdds</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/perf/core/pom.xml"
+ data-name="perf"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="perf">perf</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/reduce-logging/core/pom.xml"
+ data-name="reduce-logging"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="reduce-logging">reduce-logging</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/rxin/core/pom.xml"
+ data-name="rxin"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="rxin">rxin</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/scala-2.8/core/pom.xml"
+ data-name="scala-2.8"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="scala-2.8">scala-2.8</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/scala-2.9/core/pom.xml"
+ data-name="scala-2.9"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="scala-2.9">scala-2.9</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/scala-2.10/core/pom.xml"
+ data-name="scala-2.10"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="scala-2.10">scala-2.10</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/sched-refactoring/core/pom.xml"
+ data-name="sched-refactoring"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="sched-refactoring">sched-refactoring</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/shuffle-fix/core/pom.xml"
+ data-name="shuffle-fix"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="shuffle-fix">shuffle-fix</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/sparkplug/core/pom.xml"
+ data-name="sparkplug"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="sparkplug">sparkplug</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/streaming/core/pom.xml"
+ data-name="streaming"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="streaming">streaming</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/td-checksum/core/pom.xml"
+ data-name="td-checksum"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="td-checksum">td-checksum</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/td-rdd-save/core/pom.xml"
+ data-name="td-rdd-save"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="td-rdd-save">td-rdd-save</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/tjhunter-branch/core/pom.xml"
+ data-name="tjhunter-branch"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="tjhunter-branch">tjhunter-branch</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/twitter-mesos/core/pom.xml"
+ data-name="twitter-mesos"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="twitter-mesos">twitter-mesos</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/blob/yarn/core/pom.xml"
+ data-name="yarn"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="yarn">yarn</a>
+ </div> <!-- /.select-menu-item -->
+ </div>
+
+ <div class="select-menu-no-results">Nothing to show</div>
+ </div> <!-- /.select-menu-list -->
+
+ <div class="select-menu-list select-menu-tab-bucket js-select-menu-tab-bucket" data-tab-filter="tags">
+ <div data-filterable-for="context-commitish-filter-field" data-filterable-type="substring">
+
+
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.8.1-incubating/core/pom.xml"
+ data-name="v0.8.1-incubating"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.8.1-incubating">v0.8.1-incubating</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.8.0-incubating-rc3/core/pom.xml"
+ data-name="v0.8.0-incubating-rc3"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.8.0-incubating-rc3">v0.8.0-incubating-rc3</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.8.0-incubating/core/pom.xml"
+ data-name="v0.8.0-incubating"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.8.0-incubating">v0.8.0-incubating</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.7.2/core/pom.xml"
+ data-name="v0.7.2"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.7.2">v0.7.2</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.7.1/core/pom.xml"
+ data-name="v0.7.1"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.7.1">v0.7.1</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.7.0-bizo-1/core/pom.xml"
+ data-name="v0.7.0-bizo-1"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.7.0-bizo-1">v0.7.0-bizo-1</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.7.0/core/pom.xml"
+ data-name="v0.7.0"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.7.0">v0.7.0</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.6.2/core/pom.xml"
+ data-name="v0.6.2"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.6.2">v0.6.2</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.6.1/core/pom.xml"
+ data-name="v0.6.1"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.6.1">v0.6.1</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.6.0-yarn/core/pom.xml"
+ data-name="v0.6.0-yarn"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.6.0-yarn">v0.6.0-yarn</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.6.0/core/pom.xml"
+ data-name="v0.6.0"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.6.0">v0.6.0</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.5.2/core/pom.xml"
+ data-name="v0.5.2"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.5.2">v0.5.2</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.5.1/core/pom.xml"
+ data-name="v0.5.1"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.5.1">v0.5.1</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/v0.5.0/core/pom.xml"
+ data-name="v0.5.0"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="v0.5.0">v0.5.0</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/spark-parent-0.8.1-incubating/core/pom.xml"
+ data-name="spark-parent-0.8.1-incubating"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="spark-parent-0.8.1-incubating">spark-parent-0.8.1-incubating</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/alpha-0.2/core/pom.xml"
+ data-name="alpha-0.2"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="alpha-0.2">alpha-0.2</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/alpha-0.1/core/pom.xml"
+ data-name="alpha-0.1"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="alpha-0.1">alpha-0.1</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/0.3-scala-2.9/core/pom.xml"
+ data-name="0.3-scala-2.9"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="0.3-scala-2.9">0.3-scala-2.9</a>
+ </div> <!-- /.select-menu-item -->
+ <div class="select-menu-item js-navigation-item ">
+ <span class="select-menu-item-icon octicon octicon-check"></span>
+ <a href="/apache/incubator-spark/tree/0.3-scala-2.8/core/pom.xml"
+ data-name="0.3-scala-2.8"
+ data-skip-pjax="true"
+ rel="nofollow"
+ class="js-navigation-open select-menu-item-text js-select-button-text css-truncate-target"
+ title="0.3-scala-2.8">0.3-scala-2.8</a>
+ </div> <!-- /.select-menu-item -->
+ </div>
+
+ <div class="select-menu-no-results">Nothing to show</div>
+ </div> <!-- /.select-menu-list -->
+
+ </div> <!-- /.select-menu-modal -->
+ </div> <!-- /.select-menu-modal-holder -->
+</div> <!-- /.select-menu -->
+
+ <div class="breadcrumb">
+ <span class='repo-root js-repo-root'><span itemscope="" itemtype="http://data-vocabulary.org/Breadcrumb"><a href="/apache/incubator-spark" data-branch="master" data-direction="back" data-pjax="true" itemscope="url"><span itemprop="title">incubator-spark</span></a></span></span><span class="separator"> / </span><span itemscope="" itemtype="http://data-vocabulary.org/Breadcrumb"><a href="/apache/incubator-spark/tree/master/core" data-branch="master" data-direction="back" data-pjax="true" itemscope="url"><span itemprop="title">core</span></a></span><span class="separator"> / </span><strong class="final-path">pom.xml</strong> <span class="js-zeroclipboard minibutton zeroclipboard-button" data-clipboard-text="core/pom.xml" data-copied-hint="copied!" title="copy to clipboard"><span class="octicon octicon-clippy"></span></span>
+ </div>
+</div>
+
+
+
+ <div class="commit file-history-tease">
+ <img class="main-avatar" height="24" src="https://2.gravatar.com/avatar/def597cee3897ba290da26972ff628d2?d=https%3A%2F%2Fidenticons.github.com%2F0ec10d5e3ed3720a2d578417a894cf49.png&amp;r=x&amp;s=140" width="24" />
+ <span class="author"><a href="/pwendell" rel="author">pwendell</a></span>
+ <time class="js-relative-date" datetime="2013-12-16T21:56:21-08:00" title="2013-12-16 21:56:21">December 16, 2013</time>
+ <div class="commit-title">
+ <a href="/apache/incubator-spark/commit/c1fec89895f03dbdbb6f445ea3cdcd2d050555c4" class="message" data-pjax="true" title="Cleanup">Cleanup</a>
+ </div>
+
+ <div class="participation">
+ <p class="quickstat"><a href="#blob_contributors_box" rel="facebox"><strong>15</strong> contributors</a></p>
+ <a class="avatar tooltipped downwards" title="ScrapCodes" href="/apache/incubator-spark/commits/master/core/pom.xml?author=ScrapCodes"><img height="20" src="https://1.gravatar.com/avatar/e9813bbbab2caa993bf7e2b2d60de894?d=https%3A%2F%2Fidenticons.github.com%2F38c660c74f82a216b75167debab770ed.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="mateiz" href="/apache/incubator-spark/commits/master/core/pom.xml?author=mateiz"><img height="20" src="https://2.gravatar.com/avatar/17d3d634fc898bf3ae19444450af6805?d=https%3A%2F%2Fidenticons.github.com%2F627fcfc1b5e8e4535cc26ecfa133743e.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="jey" href="/apache/incubator-spark/commits/master/core/pom.xml?author=jey"><img height="20" src="https://1.gravatar.com/avatar/85a3f16c85ee6e38527e82e035ed4bf9?d=https%3A%2F%2Fidenticons.github.com%2F2ebb6c06bdc16ef37ec965c6b325b5c6.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="markhamstra" href="/apache/incubator-spark/commits/master/core/pom.xml?author=markhamstra"><img height="20" src="https://2.gravatar.com/avatar/aaa2e907159f6919c4a4c8039d46752f?d=https%3A%2F%2Fidenticons.github.com%2F80aa657180765c73a93528281452d8dc.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="pwendell" href="/apache/incubator-spark/commits/master/core/pom.xml?author=pwendell"><img height="20" src="https://2.gravatar.com/avatar/def597cee3897ba290da26972ff628d2?d=https%3A%2F%2Fidenticons.github.com%2F0ec10d5e3ed3720a2d578417a894cf49.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="rxin" href="/apache/incubator-spark/commits/master/core/pom.xml?author=rxin"><img height="20" src="https://0.gravatar.com/avatar/73bfac91877fce35c6d20e16e9e53677?d=https%3A%2F%2Fidenticons.github.com%2F017ef876ba320362d11baf01bb584eee.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="tomdz" href="/apache/incubator-spark/commits/master/core/pom.xml?author=tomdz"><img height="20" src="https://2.gravatar.com/avatar/18039bff98071ad398b4301cfb0522b4?d=https%3A%2F%2Fidenticons.github.com%2F95b22f6d3c85b4945466aee72a712a01.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="shivaram" href="/apache/incubator-spark/commits/master/core/pom.xml?author=shivaram"><img height="20" src="https://0.gravatar.com/avatar/9d61244268a5be0ab361dc6e4af65ee4?d=https%3A%2F%2Fidenticons.github.com%2F020255c0f2993d01f44c91e0b508fa84.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="mridulm" href="/apache/incubator-spark/commits/master/core/pom.xml?author=mridulm"><img height="20" src="https://1.gravatar.com/avatar/0c6f3ad3eb45b1c42ca026f9d6ff3794?d=https%3A%2F%2Fidenticons.github.com%2Feed4d497b06cb847301ab49f40608dd4.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="woggle" href="/apache/incubator-spark/commits/master/core/pom.xml?author=woggle"><img height="20" src="https://1.gravatar.com/avatar/86f3647dccd886de735435991f55848e?d=https%3A%2F%2Fidenticons.github.com%2Ff23fdb10552dbc5ec1a5c9ba8e6f6be3.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="colorant" href="/apache/incubator-spark/commits/master/core/pom.xml?author=colorant"><img height="20" src="https://0.gravatar.com/avatar/1e54a90ff3671ca36be8a1163038ca56?d=https%3A%2F%2Fidenticons.github.com%2F90d602b3c6ec06eda12e485e109415fe.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="jerryshao" href="/apache/incubator-spark/commits/master/core/pom.xml?author=jerryshao"><img height="20" src="https://2.gravatar.com/avatar/24fb58dab7ef7c3202c4c8061ee51a12?d=https%3A%2F%2Fidenticons.github.com%2Ff8737f93e1f940e20f1259ad5403be32.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="russellcardullo" href="/apache/incubator-spark/commits/master/core/pom.xml?author=russellcardullo"><img height="20" src="https://0.gravatar.com/avatar/4173fab43d50a19e42950a8f7c7d96f8?d=https%3A%2F%2Fidenticons.github.com%2Fcae42ebaab97a0f8a1feb2f3fc97ee91.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="folone" href="/apache/incubator-spark/commits/master/core/pom.xml?author=folone"><img height="20" src="https://2.gravatar.com/avatar/50e7e3f60b3507383d2b327857b66a62?d=https%3A%2F%2Fidenticons.github.com%2F6c2bcbccd23191b40f4932e2b8450681.png&amp;r=x&amp;s=140" width="20" /></a>
+ <a class="avatar tooltipped downwards" title="witgo" href="/apache/incubator-spark/commits/master/core/pom.xml?author=witgo"><img height="20" src="https://2.gravatar.com/avatar/82f14c57ba7522241b0e3c67d759609e?d=https%3A%2F%2Fidenticons.github.com%2F878f3103ff85a3edb2d415930bbdbd5a.png&amp;r=x&amp;s=140" width="20" /></a>
+
+
+ </div>
+ <div id="blob_contributors_box" style="display:none">
+ <h2 class="facebox-header">Users who have contributed to this file</h2>
+ <ul class="facebox-user-list">
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://1.gravatar.com/avatar/e9813bbbab2caa993bf7e2b2d60de894?d=https%3A%2F%2Fidenticons.github.com%2F38c660c74f82a216b75167debab770ed.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/ScrapCodes">ScrapCodes</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/17d3d634fc898bf3ae19444450af6805?d=https%3A%2F%2Fidenticons.github.com%2F627fcfc1b5e8e4535cc26ecfa133743e.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/mateiz">mateiz</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://1.gravatar.com/avatar/85a3f16c85ee6e38527e82e035ed4bf9?d=https%3A%2F%2Fidenticons.github.com%2F2ebb6c06bdc16ef37ec965c6b325b5c6.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/jey">jey</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/aaa2e907159f6919c4a4c8039d46752f?d=https%3A%2F%2Fidenticons.github.com%2F80aa657180765c73a93528281452d8dc.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/markhamstra">markhamstra</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/def597cee3897ba290da26972ff628d2?d=https%3A%2F%2Fidenticons.github.com%2F0ec10d5e3ed3720a2d578417a894cf49.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/pwendell">pwendell</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://0.gravatar.com/avatar/73bfac91877fce35c6d20e16e9e53677?d=https%3A%2F%2Fidenticons.github.com%2F017ef876ba320362d11baf01bb584eee.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/rxin">rxin</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/18039bff98071ad398b4301cfb0522b4?d=https%3A%2F%2Fidenticons.github.com%2F95b22f6d3c85b4945466aee72a712a01.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/tomdz">tomdz</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://0.gravatar.com/avatar/9d61244268a5be0ab361dc6e4af65ee4?d=https%3A%2F%2Fidenticons.github.com%2F020255c0f2993d01f44c91e0b508fa84.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/shivaram">shivaram</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://1.gravatar.com/avatar/0c6f3ad3eb45b1c42ca026f9d6ff3794?d=https%3A%2F%2Fidenticons.github.com%2Feed4d497b06cb847301ab49f40608dd4.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/mridulm">mridulm</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://1.gravatar.com/avatar/86f3647dccd886de735435991f55848e?d=https%3A%2F%2Fidenticons.github.com%2Ff23fdb10552dbc5ec1a5c9ba8e6f6be3.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/woggle">woggle</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://0.gravatar.com/avatar/1e54a90ff3671ca36be8a1163038ca56?d=https%3A%2F%2Fidenticons.github.com%2F90d602b3c6ec06eda12e485e109415fe.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/colorant">colorant</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/24fb58dab7ef7c3202c4c8061ee51a12?d=https%3A%2F%2Fidenticons.github.com%2Ff8737f93e1f940e20f1259ad5403be32.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/jerryshao">jerryshao</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://0.gravatar.com/avatar/4173fab43d50a19e42950a8f7c7d96f8?d=https%3A%2F%2Fidenticons.github.com%2Fcae42ebaab97a0f8a1feb2f3fc97ee91.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/russellcardullo">russellcardullo</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/50e7e3f60b3507383d2b327857b66a62?d=https%3A%2F%2Fidenticons.github.com%2F6c2bcbccd23191b40f4932e2b8450681.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/folone">folone</a>
+ </li>
+ <li class="facebox-user-list-item">
+ <img height="24" src="https://2.gravatar.com/avatar/82f14c57ba7522241b0e3c67d759609e?d=https%3A%2F%2Fidenticons.github.com%2F878f3103ff85a3edb2d415930bbdbd5a.png&amp;r=x&amp;s=140" width="24" />
+ <a href="/witgo">witgo</a>
+ </li>
+ </ul>
+ </div>
+ </div>
+
+<div id="files" class="bubble">
+ <div class="file">
+ <div class="meta">
+ <div class="info">
+ <span class="icon"><b class="octicon octicon-file-text"></b></span>
+ <span class="mode" title="File Mode">file</span>
+ <span>232 lines (228 sloc)</span>
+ <span>7.715 kb</span>
+ </div>
+ <div class="actions">
+ <div class="button-group">
+ <a class="minibutton disabled tooltipped leftwards" href="#"
+ title="You must be signed in to make or propose changes">Edit</a>
+ <a href="/apache/incubator-spark/raw/master/core/pom.xml" class="button minibutton " id="raw-url">Raw</a>
+ <a href="/apache/incubator-spark/blame/master/core/pom.xml" class="button minibutton ">Blame</a>
+ <a href="/apache/incubator-spark/commits/master/core/pom.xml" class="button minibutton " rel="nofollow">History</a>
+ </div><!-- /.button-group -->
+ <a class="minibutton danger disabled empty-icon tooltipped leftwards" href="#"
+ title="You must be signed in to make or propose changes">
+ Delete
+ </a>
+ </div><!-- /.actions -->
+
+ </div>
+ <div class="blob-wrapper data type-xml js-blob-data">
+ <table class="file-code file-diff">
+ <tr class="file-code-line">
+ <td class="blob-line-nums">
+ <span id="L1" rel="#L1">1</span>
+<span id="L2" rel="#L2">2</span>
+<span id="L3" rel="#L3">3</span>
+<span id="L4" rel="#L4">4</span>
+<span id="L5" rel="#L5">5</span>
+<span id="L6" rel="#L6">6</span>
+<span id="L7" rel="#L7">7</span>
+<span id="L8" rel="#L8">8</span>
+<span id="L9" rel="#L9">9</span>
+<span id="L10" rel="#L10">10</span>
+<span id="L11" rel="#L11">11</span>
+<span id="L12" rel="#L12">12</span>
+<span id="L13" rel="#L13">13</span>
+<span id="L14" rel="#L14">14</span>
+<span id="L15" rel="#L15">15</span>
+<span id="L16" rel="#L16">16</span>
+<span id="L17" rel="#L17">17</span>
+<span id="L18" rel="#L18">18</span>
+<span id="L19" rel="#L19">19</span>
+<span id="L20" rel="#L20">20</span>
+<span id="L21" rel="#L21">21</span>
+<span id="L22" rel="#L22">22</span>
+<span id="L23" rel="#L23">23</span>
+<span id="L24" rel="#L24">24</span>
+<span id="L25" rel="#L25">25</span>
+<span id="L26" rel="#L26">26</span>
+<span id="L27" rel="#L27">27</span>
+<span id="L28" rel="#L28">28</span>
+<span id="L29" rel="#L29">29</span>
+<span id="L30" rel="#L30">30</span>
+<span id="L31" rel="#L31">31</span>
+<span id="L32" rel="#L32">32</span>
+<span id="L33" rel="#L33">33</span>
+<span id="L34" rel="#L34">34</span>
+<span id="L35" rel="#L35">35</span>
+<span id="L36" rel="#L36">36</span>
+<span id="L37" rel="#L37">37</span>
+<span id="L38" rel="#L38">38</span>
+<span id="L39" rel="#L39">39</span>
+<span id="L40" rel="#L40">40</span>
+<span id="L41" rel="#L41">41</span>
+<span id="L42" rel="#L42">42</span>
+<span id="L43" rel="#L43">43</span>
+<span id="L44" rel="#L44">44</span>
+<span id="L45" rel="#L45">45</span>
+<span id="L46" rel="#L46">46</span>
+<span id="L47" rel="#L47">47</span>
+<span id="L48" rel="#L48">48</span>
+<span id="L49" rel="#L49">49</span>
+<span id="L50" rel="#L50">50</span>
+<span id="L51" rel="#L51">51</span>
+<span id="L52" rel="#L52">52</span>
+<span id="L53" rel="#L53">53</span>
+<span id="L54" rel="#L54">54</span>
+<span id="L55" rel="#L55">55</span>
+<span id="L56" rel="#L56">56</span>
+<span id="L57" rel="#L57">57</span>
+<span id="L58" rel="#L58">58</span>
+<span id="L59" rel="#L59">59</span>
+<span id="L60" rel="#L60">60</span>
+<span id="L61" rel="#L61">61</span>
+<span id="L62" rel="#L62">62</span>
+<span id="L63" rel="#L63">63</span>
+<span id="L64" rel="#L64">64</span>
+<span id="L65" rel="#L65">65</span>
+<span id="L66" rel="#L66">66</span>
+<span id="L67" rel="#L67">67</span>
+<span id="L68" rel="#L68">68</span>
+<span id="L69" rel="#L69">69</span>
+<span id="L70" rel="#L70">70</span>
+<span id="L71" rel="#L71">71</span>
+<span id="L72" rel="#L72">72</span>
+<span id="L73" rel="#L73">73</span>
+<span id="L74" rel="#L74">74</span>
+<span id="L75" rel="#L75">75</span>
+<span id="L76" rel="#L76">76</span>
+<span id="L77" rel="#L77">77</span>
+<span id="L78" rel="#L78">78</span>
+<span id="L79" rel="#L79">79</span>
+<span id="L80" rel="#L80">80</span>
+<span id="L81" rel="#L81">81</span>
+<span id="L82" rel="#L82">82</span>
+<span id="L83" rel="#L83">83</span>
+<span id="L84" rel="#L84">84</span>
+<span id="L85" rel="#L85">85</span>
+<span id="L86" rel="#L86">86</span>
+<span id="L87" rel="#L87">87</span>
+<span id="L88" rel="#L88">88</span>
+<span id="L89" rel="#L89">89</span>
+<span id="L90" rel="#L90">90</span>
+<span id="L91" rel="#L91">91</span>
+<span id="L92" rel="#L92">92</span>
+<span id="L93" rel="#L93">93</span>
+<span id="L94" rel="#L94">94</span>
+<span id="L95" rel="#L95">95</span>
+<span id="L96" rel="#L96">96</span>
+<span id="L97" rel="#L97">97</span>
+<span id="L98" rel="#L98">98</span>
+<span id="L99" rel="#L99">99</span>
+<span id="L100" rel="#L100">100</span>
+<span id="L101" rel="#L101">101</span>
+<span id="L102" rel="#L102">102</span>
+<span id="L103" rel="#L103">103</span>
+<span id="L104" rel="#L104">104</span>
+<span id="L105" rel="#L105">105</span>
+<span id="L106" rel="#L106">106</span>
+<span id="L107" rel="#L107">107</span>
+<span id="L108" rel="#L108">108</span>
+<span id="L109" rel="#L109">109</span>
+<span id="L110" rel="#L110">110</span>
+<span id="L111" rel="#L111">111</span>
+<span id="L112" rel="#L112">112</span>
+<span id="L113" rel="#L113">113</span>
+<span id="L114" rel="#L114">114</span>
+<span id="L115" rel="#L115">115</span>
+<span id="L116" rel="#L116">116</span>
+<span id="L117" rel="#L117">117</span>
+<span id="L118" rel="#L118">118</span>
+<span id="L119" rel="#L119">119</span>
+<span id="L120" rel="#L120">120</span>
+<span id="L121" rel="#L121">121</span>
+<span id="L122" rel="#L122">122</span>
+<span id="L123" rel="#L123">123</span>
+<span id="L124" rel="#L124">124</span>
+<span id="L125" rel="#L125">125</span>
+<span id="L126" rel="#L126">126</span>
+<span id="L127" rel="#L127">127</span>
+<span id="L128" rel="#L128">128</span>
+<span id="L129" rel="#L129">129</span>
+<span id="L130" rel="#L130">130</span>
+<span id="L131" rel="#L131">131</span>
+<span id="L132" rel="#L132">132</span>
+<span id="L133" rel="#L133">133</span>
+<span id="L134" rel="#L134">134</span>
+<span id="L135" rel="#L135">135</span>
+<span id="L136" rel="#L136">136</span>
+<span id="L137" rel="#L137">137</span>
+<span id="L138" rel="#L138">138</span>
+<span id="L139" rel="#L139">139</span>
+<span id="L140" rel="#L140">140</span>
+<span id="L141" rel="#L141">141</span>
+<span id="L142" rel="#L142">142</span>
+<span id="L143" rel="#L143">143</span>
+<span id="L144" rel="#L144">144</span>
+<span id="L145" rel="#L145">145</span>
+<span id="L146" rel="#L146">146</span>
+<span id="L147" rel="#L147">147</span>
+<span id="L148" rel="#L148">148</span>
+<span id="L149" rel="#L149">149</span>
+<span id="L150" rel="#L150">150</span>
+<span id="L151" rel="#L151">151</span>
+<span id="L152" rel="#L152">152</span>
+<span id="L153" rel="#L153">153</span>
+<span id="L154" rel="#L154">154</span>
+<span id="L155" rel="#L155">155</span>
+<span id="L156" rel="#L156">156</span>
+<span id="L157" rel="#L157">157</span>
+<span id="L158" rel="#L158">158</span>
+<span id="L159" rel="#L159">159</span>
+<span id="L160" rel="#L160">160</span>
+<span id="L161" rel="#L161">161</span>
+<span id="L162" rel="#L162">162</span>
+<span id="L163" rel="#L163">163</span>
+<span id="L164" rel="#L164">164</span>
+<span id="L165" rel="#L165">165</span>
+<span id="L166" rel="#L166">166</span>
+<span id="L167" rel="#L167">167</span>
+<span id="L168" rel="#L168">168</span>
+<span id="L169" rel="#L169">169</span>
+<span id="L170" rel="#L170">170</span>
+<span id="L171" rel="#L171">171</span>
+<span id="L172" rel="#L172">172</span>
+<span id="L173" rel="#L173">173</span>
+<span id="L174" rel="#L174">174</span>
+<span id="L175" rel="#L175">175</span>
+<span id="L176" rel="#L176">176</span>
+<span id="L177" rel="#L177">177</span>
+<span id="L178" rel="#L178">178</span>
+<span id="L179" rel="#L179">179</span>
+<span id="L180" rel="#L180">180</span>
+<span id="L181" rel="#L181">181</span>
+<span id="L182" rel="#L182">182</span>
+<span id="L183" rel="#L183">183</span>
+<span id="L184" rel="#L184">184</span>
+<span id="L185" rel="#L185">185</span>
+<span id="L186" rel="#L186">186</span>
+<span id="L187" rel="#L187">187</span>
+<span id="L188" rel="#L188">188</span>
+<span id="L189" rel="#L189">189</span>
+<span id="L190" rel="#L190">190</span>
+<span id="L191" rel="#L191">191</span>
+<span id="L192" rel="#L192">192</span>
+<span id="L193" rel="#L193">193</span>
+<span id="L194" rel="#L194">194</span>
+<span id="L195" rel="#L195">195</span>
+<span id="L196" rel="#L196">196</span>
+<span id="L197" rel="#L197">197</span>
+<span id="L198" rel="#L198">198</span>
+<span id="L199" rel="#L199">199</span>
+<span id="L200" rel="#L200">200</span>
+<span id="L201" rel="#L201">201</span>
+<span id="L202" rel="#L202">202</span>
+<span id="L203" rel="#L203">203</span>
+<span id="L204" rel="#L204">204</span>
+<span id="L205" rel="#L205">205</span>
+<span id="L206" rel="#L206">206</span>
+<span id="L207" rel="#L207">207</span>
+<span id="L208" rel="#L208">208</span>
+<span id="L209" rel="#L209">209</span>
+<span id="L210" rel="#L210">210</span>
+<span id="L211" rel="#L211">211</span>
+<span id="L212" rel="#L212">212</span>
+<span id="L213" rel="#L213">213</span>
+<span id="L214" rel="#L214">214</span>
+<span id="L215" rel="#L215">215</span>
+<span id="L216" rel="#L216">216</span>
+<span id="L217" rel="#L217">217</span>
+<span id="L218" rel="#L218">218</span>
+<span id="L219" rel="#L219">219</span>
+<span id="L220" rel="#L220">220</span>
+<span id="L221" rel="#L221">221</span>
+<span id="L222" rel="#L222">222</span>
+<span id="L223" rel="#L223">223</span>
+<span id="L224" rel="#L224">224</span>
+<span id="L225" rel="#L225">225</span>
+<span id="L226" rel="#L226">226</span>
+<span id="L227" rel="#L227">227</span>
+<span id="L228" rel="#L228">228</span>
+<span id="L229" rel="#L229">229</span>
+<span id="L230" rel="#L230">230</span>
+<span id="L231" rel="#L231">231</span>
+
+ </td>
+ <td class="blob-line-code">
+ <div class="code-body highlight"><pre><div class='line' id='LC1'><span class="cp">&lt;?xml version=&quot;1.0&quot; encoding=&quot;UTF-8&quot;?&gt;</span></div><div class='line' id='LC2'><span class="c">&lt;!--</span></div><div class='line' id='LC3'><span class="c"> ~ Licensed to the Apache Software Foundation (ASF) under one or more</span></div><div class='line' id='LC4'><span class="c"> ~ contributor license agreements. See the NOTICE file distributed with</span></div><div class='line' id='LC5'><span class="c"> ~ this work for additional information regarding copyright ownership.</span></div><div class='line' id='LC6'><span class="c"> ~ The ASF licenses this file to You under the Apache License, Version 2.0</span></div><div class='line' id='LC7'><span class="c"> ~ (the &quot;License&quot;); you may not use this file except in compliance with</span></div><div class='line' id='LC8'><span class="c"> ~ the License. You may obtain a copy of the License at</span></div><div class='line' id='LC9'><span class="c"> ~</span></div><div class='line' id='LC10'><span class="c"> ~ http://www.apache.org/licenses/LICENSE-2.0</span></div><div class='line' id='LC11'><span class="c"> ~</span></div><div class='line' id='LC12'><span class="c"> ~ Unless required by applicable law or agreed to in writing, software</span></div><div class='line' id='LC13'><span class="c"> ~ distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span></div><div class='line' id='LC14'><span class="c"> ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span></div><div class='line' id='LC15'><span class="c"> ~ See the License for the specific language governing permissions and</span></div><div class='line' id='LC16'><span class="c"> ~ limitations under the License.</span></div><div class='line' id='LC17'><span class="c"> --&gt;</span></div><div class='line' id='LC18'><br/></div><div class='line' id='LC19'><span class="nt">&lt;project</span> <span class="na">xmlns=</span><span class="s">&quot;http://maven.apache.org/POM/4.0.0&quot;</span> <span class="na">xmlns:xsi=</span><span class="s">&quot;http://www.w3.org/2001/XMLSchema-instance&quot;</span> <span class="na">xsi:schemaLocation=</span><span class="s">&quot;http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd&quot;</span><span class="nt">&gt;</span></div><div class='line' id='LC20'>&nbsp;&nbsp;<span class="nt">&lt;modelVersion&gt;</span>4.0.0<span class="nt">&lt;/modelVersion&gt;</span></div><div class='line' id='LC21'>&nbsp;&nbsp;<span class="nt">&lt;parent&gt;</span></div><div class='line' id='LC22'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.spark<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC23'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>spark-parent<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC24'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;version&gt;</span>0.9.0-incubating-SNAPSHOT<span class="nt">&lt;/version&gt;</span></div><div class='line' id='LC25'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;relativePath&gt;</span>../pom.xml<span class="nt">&lt;/relativePath&gt;</span></div><div class='line' id='LC26'>&nbsp;&nbsp;<span class="nt">&lt;/parent&gt;</span></div><div class='line' id='LC27'><br/></div><div class='line' id='LC28'>&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.spark<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC29'>&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>spark-core_2.10<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC30'>&nbsp;&nbsp;<span class="nt">&lt;packaging&gt;</span>jar<span class="nt">&lt;/packaging&gt;</span></div><div class='line' id='LC31'>&nbsp;&nbsp;<span class="nt">&lt;name&gt;</span>Spark Project Core<span class="nt">&lt;/name&gt;</span></div><div class='line' id='LC32'>&nbsp;&nbsp;<span class="nt">&lt;url&gt;</span>http://spark.incubator.apache.org/<span class="nt">&lt;/url&gt;</span></div><div class='line' id='LC33'><br/></div><div class='line' id='LC34'>&nbsp;&nbsp;<span class="nt">&lt;dependencies&gt;</span></div><div class='line' id='LC35'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC36'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.hadoop<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC37'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>hadoop-client<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC38'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC39'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC40'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>net.java.dev.jets3t<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC41'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>jets3t<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC42'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC43'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC44'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.avro<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC45'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>avro<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC46'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC47'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC48'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.avro<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC49'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>avro-ipc<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC50'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC51'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC52'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.zookeeper<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC53'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>zookeeper<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC54'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC55'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC56'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.eclipse.jetty<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC57'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>jetty-server<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC58'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC59'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC60'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.google.guava<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC61'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>guava<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC62'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC63'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC64'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.google.code.findbugs<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC65'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>jsr305<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC66'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC67'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC68'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.slf4j<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC69'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>slf4j-api<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC70'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC71'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC72'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.ning<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC73'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>compress-lzf<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC74'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC75'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC76'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.xerial.snappy<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC77'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>snappy-java<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC78'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC79'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC80'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.ow2.asm<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC81'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>asm<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC82'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC83'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC84'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.twitter<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC85'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>chill_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC86'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;version&gt;</span>0.3.1<span class="nt">&lt;/version&gt;</span></div><div class='line' id='LC87'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC88'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC89'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.twitter<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC90'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>chill-java<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC91'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;version&gt;</span>0.3.1<span class="nt">&lt;/version&gt;</span></div><div class='line' id='LC92'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC93'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC94'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>${akka.group}<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC95'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>akka-remote_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC96'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC97'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC98'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>${akka.group}<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC99'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>akka-slf4j_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC100'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC101'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC102'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.scala-lang<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC103'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>scala-library<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC104'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC105'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC106'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>net.liftweb<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC107'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>lift-json_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC108'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC109'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC110'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>it.unimi.dsi<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC111'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>fastutil<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC112'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC113'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC114'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>colt<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC115'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>colt<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC116'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC117'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC118'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.mesos<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC119'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>mesos<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC120'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC121'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC122'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>io.netty<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC123'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>netty-all<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC124'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC125'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC126'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>log4j<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC127'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>log4j<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC128'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC129'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC130'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.codahale.metrics<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC131'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>metrics-core<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC132'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC133'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC134'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.codahale.metrics<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC135'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>metrics-jvm<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC136'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC137'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC138'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.codahale.metrics<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC139'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>metrics-json<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC140'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC141'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC142'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.codahale.metrics<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC143'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>metrics-ganglia<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC144'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC145'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC146'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.codahale.metrics<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC147'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>metrics-graphite<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC148'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC149'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC150'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.derby<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC151'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>derby<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC152'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC153'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC154'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC155'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>commons-io<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC156'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>commons-io<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC157'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC158'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC159'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC160'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.scalatest<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC161'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>scalatest_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC162'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC163'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC164'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC165'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.scalacheck<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC166'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>scalacheck_${scala.binary.version}<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC167'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC168'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC169'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC170'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.easymock<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC171'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>easymock<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC172'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC173'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC174'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC175'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>com.novocode<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC176'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>junit-interface<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC177'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC178'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC179'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;dependency&gt;</span></div><div class='line' id='LC180'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.slf4j<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC181'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>slf4j-log4j12<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC182'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;scope&gt;</span>test<span class="nt">&lt;/scope&gt;</span></div><div class='line' id='LC183'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/dependency&gt;</span></div><div class='line' id='LC184'>&nbsp;&nbsp;<span class="nt">&lt;/dependencies&gt;</span></div><div class='line' id='LC185'>&nbsp;&nbsp;<span class="nt">&lt;build&gt;</span></div><div class='line' id='LC186'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;outputDirectory&gt;</span>target/scala-${scala.binary.version}/classes<span class="nt">&lt;/outputDirectory&gt;</span></div><div class='line' id='LC187'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;testOutputDirectory&gt;</span>target/scala-${scala.binary.version}/test-classes<span class="nt">&lt;/testOutputDirectory&gt;</span></div><div class='line' id='LC188'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;plugins&gt;</span></div><div class='line' id='LC189'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;plugin&gt;</span></div><div class='line' id='LC190'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.apache.maven.plugins<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC191'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>maven-antrun-plugin<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC192'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;executions&gt;</span></div><div class='line' id='LC193'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;execution&gt;</span></div><div class='line' id='LC194'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;phase&gt;</span>test<span class="nt">&lt;/phase&gt;</span></div><div class='line' id='LC195'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;goals&gt;</span></div><div class='line' id='LC196'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;goal&gt;</span>run<span class="nt">&lt;/goal&gt;</span></div><div class='line' id='LC197'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/goals&gt;</span></div><div class='line' id='LC198'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;configuration&gt;</span></div><div class='line' id='LC199'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;exportAntProperties&gt;</span>true<span class="nt">&lt;/exportAntProperties&gt;</span></div><div class='line' id='LC200'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;tasks&gt;</span></div><div class='line' id='LC201'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;property</span> <span class="na">name=</span><span class="s">&quot;spark.classpath&quot;</span> <span class="na">refid=</span><span class="s">&quot;maven.test.classpath&quot;</span> <span class="nt">/&gt;</span></div><div class='line' id='LC202'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;property</span> <span class="na">environment=</span><span class="s">&quot;env&quot;</span> <span class="nt">/&gt;</span></div><div class='line' id='LC203'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;fail</span> <span class="na">message=</span><span class="s">&quot;Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.&quot;</span><span class="nt">&gt;</span></div><div class='line' id='LC204'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;condition&gt;</span></div><div class='line' id='LC205'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;not&gt;</span></div><div class='line' id='LC206'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;or&gt;</span></div><div class='line' id='LC207'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;isset</span> <span class="na">property=</span><span class="s">&quot;env.SCALA_HOME&quot;</span> <span class="nt">/&gt;</span></div><div class='line' id='LC208'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;isset</span> <span class="na">property=</span><span class="s">&quot;env.SCALA_LIBRARY_PATH&quot;</span> <span class="nt">/&gt;</span></div><div class='line' id='LC209'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/or&gt;</span></div><div class='line' id='LC210'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/not&gt;</span></div><div class='line' id='LC211'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/condition&gt;</span></div><div class='line' id='LC212'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/fail&gt;</span></div><div class='line' id='LC213'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/tasks&gt;</span></div><div class='line' id='LC214'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/configuration&gt;</span></div><div class='line' id='LC215'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/execution&gt;</span></div><div class='line' id='LC216'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/executions&gt;</span></div><div class='line' id='LC217'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/plugin&gt;</span></div><div class='line' id='LC218'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;plugin&gt;</span></div><div class='line' id='LC219'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;groupId&gt;</span>org.scalatest<span class="nt">&lt;/groupId&gt;</span></div><div class='line' id='LC220'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;artifactId&gt;</span>scalatest-maven-plugin<span class="nt">&lt;/artifactId&gt;</span></div><div class='line' id='LC221'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;configuration&gt;</span></div><div class='line' id='LC222'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;environmentVariables&gt;</span></div><div class='line' id='LC223'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;SPARK_HOME&gt;</span>${basedir}/..<span class="nt">&lt;/SPARK_HOME&gt;</span></div><div class='line' id='LC224'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;SPARK_TESTING&gt;</span>1<span class="nt">&lt;/SPARK_TESTING&gt;</span></div><div class='line' id='LC225'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;SPARK_CLASSPATH&gt;</span>${spark.classpath}<span class="nt">&lt;/SPARK_CLASSPATH&gt;</span></div><div class='line' id='LC226'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/environmentVariables&gt;</span></div><div class='line' id='LC227'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/configuration&gt;</span></div><div class='line' id='LC228'>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/plugin&gt;</span></div><div class='line' id='LC229'>&nbsp;&nbsp;&nbsp;&nbsp;<span class="nt">&lt;/plugins&gt;</span></div><div class='line' id='LC230'>&nbsp;&nbsp;<span class="nt">&lt;/build&gt;</span></div><div class='line' id='LC231'><span class="nt">&lt;/project&gt;</span></div></pre></div>
+ </td>
+ </tr>
+ </table>
+ </div>
+
+ </div>
+</div>
+
+<a href="#jump-to-line" rel="facebox[.linejump]" data-hotkey="l" class="js-jump-to-line" style="display:none">Jump to Line</a>
+<div id="jump-to-line" style="display:none">
+ <form accept-charset="UTF-8" class="js-jump-to-line-form">
+ <input class="linejump-input js-jump-to-line-field" type="text" placeholder="Jump to line&hellip;" autofocus>
+ <button type="submit" class="button">Go</button>
+ </form>
+</div>
+
+ </div>
+
+ </div><!-- /.repo-container -->
+ <div class="modal-backdrop"></div>
+ </div><!-- /.container -->
+ </div><!-- /.site -->
+
+
+ </div><!-- /.wrapper -->
+
+ <div class="container">
+ <div class="site-footer">
+ <ul class="site-footer-links right">
+ <li><a href="https://status.github.com/">Status</a></li>
+ <li><a href="http://developer.github.com">API</a></li>
+ <li><a href="http://training.github.com">Training</a></li>
+ <li><a href="http://shop.github.com">Shop</a></li>
+ <li><a href="/blog">Blog</a></li>
+ <li><a href="/about">About</a></li>
+
+ </ul>
+
+ <a href="/">
+ <span class="mega-octicon octicon-mark-github"></span>
+ </a>
+
+ <ul class="site-footer-links">
+ <li>&copy; 2013 <span title="0.02474s from github-fe126-cp1-prd.iad.github.net">GitHub</span>, Inc.</li>
+ <li><a href="/site/terms">Terms</a></li>
+ <li><a href="/site/privacy">Privacy</a></li>
+ <li><a href="/security">Security</a></li>
+ <li><a href="/contact">Contact</a></li>
+ </ul>
+ </div><!-- /.site-footer -->
+</div><!-- /.container -->
+
+
+ <div class="fullscreen-overlay js-fullscreen-overlay" id="fullscreen_overlay">
+ <div class="fullscreen-container js-fullscreen-container">
+ <div class="textarea-wrap">
+ <textarea name="fullscreen-contents" id="fullscreen-contents" class="js-fullscreen-contents" placeholder="" data-suggester="fullscreen_suggester"></textarea>
+ <div class="suggester-container">
+ <div class="suggester fullscreen-suggester js-navigation-container" id="fullscreen_suggester"
+ data-url="/apache/incubator-spark/suggestions/commit">
+ </div>
+ </div>
+ </div>
+ </div>
+ <div class="fullscreen-sidebar">
+ <a href="#" class="exit-fullscreen js-exit-fullscreen tooltipped leftwards" title="Exit Zen Mode">
+ <span class="mega-octicon octicon-screen-normal"></span>
+ </a>
+ <a href="#" class="theme-switcher js-theme-switcher tooltipped leftwards"
+ title="Switch themes">
+ <span class="octicon octicon-color-mode"></span>
+ </a>
+ </div>
+</div>
+
+
+
+ <div id="ajax-error-message" class="flash flash-error">
+ <span class="octicon octicon-alert"></span>
+ <a href="#" class="octicon octicon-remove-close close ajax-error-dismiss"></a>
+ Something went wrong with that request. Please try again.
+ </div>
+
+ </body>
+</html>
+
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java
index 20a7a3aa8c..edd0fc56f8 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java
@@ -19,8 +19,6 @@ package org.apache.spark.network.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java
index 666432474d..a99af348ce 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java
@@ -20,7 +20,6 @@ package org.apache.spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index cfd8132891..172c6e4b1c 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -25,6 +25,7 @@ import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -37,40 +38,34 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
- String path = pResolver.getAbsolutePath(blockId.name());
- // if getFilePath returns null, close the channel
- if (path == null) {
+ FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+ // if getBlockLocation returns null, close the channel
+ if (fileSegment == null) {
//ctx.close();
return;
}
- File file = new File(path);
+ File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
- //logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
- long length = file.length();
+ long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
- //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
- //logger.info("Sending block "+blockId+" filelen = "+len);
- //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
- .getChannel(), 0, file.length()));
+ .getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
- //logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
- //logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
index 94c034cad0..9f7ced44cf 100755
--- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
+++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
+private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+ "org.apache.hadoop.mapred.JobContext")
+ val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+ classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+ "org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
+import org.apache.hadoop.conf.Configuration
+private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
+ val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
- // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+ // First, attempt to use the old-style constructor that takes a boolean isMap
+ // (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ classOf[Int], classOf[Int])
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
- // failed, look for the new ctor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+ // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+ val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ .asInstanceOf[Class[Enum[_]]]
+ val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+ taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 1ad9240cfa..c6b4ac5192 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
- } else {
+ } else jobWaiter.synchronized {
val finishTime = System.currentTimeMillis() + atMost.toMillis
while (!isCompleted) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1e3f1ebfaf..ccffcc356c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,31 +20,28 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch._
import akka.pattern.ask
-import akka.remote._
-import akka.util.Duration
-
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
-
+import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -55,30 +52,37 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
private[spark] class MapOutputTracker extends Logging {
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
+ private val timeout = AkkaUtils.askTimeout
+
// Set to the MapOutputTrackerActor living on the driver
- var trackerActor: ActorRef = _
+ var trackerActor: Either[ActorRef, ActorSelection] = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
+ /*
+ The difference between ActorRef and ActorSelection is well explained here:
+ http://doc.akka.io/docs/akka/2.2.3/project/migration-guide-2.1.x-2.2.x.html#Use_actorSelection_instead_of_actorFor
+ In spark a map output tracker can be either started on Driver where it is created which
+ is an ActorRef or it can be on executor from where it is looked up which is an
+ actorSelection.
+ */
+ val future = trackerActor match {
+ case Left(a: ActorRef) => a.ask(message)(timeout)
+ case Right(b: ActorSelection) => b.ask(message)(timeout)
+ }
+ Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with MapOutputTracker", e)
@@ -86,50 +90,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -159,7 +125,7 @@ private[spark] class MapOutputTracker extends Logging {
fetching += shuffleId
}
}
-
+
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
@@ -168,7 +134,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -186,7 +152,7 @@ private[spark] class MapOutputTracker extends Logging {
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ }
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
@@ -194,9 +160,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -206,15 +171,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -228,14 +185,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
+ }
+ }
+
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -247,13 +252,13 @@ private[spark] class MapOutputTracker extends Logging {
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses(shuffleId)
+ statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
- // out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -261,13 +266,35 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
}
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
+ }
+
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+
+ def has(shuffleId: Int): Boolean = {
+ cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -278,18 +305,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
@@ -300,7 +320,7 @@ private[spark] object MapOutputTracker {
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
assert (statuses != null)
statuses.map {
- status =>
+ status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 0e2c987a59..bcec41c439 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,8 +17,10 @@
package org.apache.spark
-import org.apache.spark.util.Utils
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -72,7 +74,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
-
+
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
@@ -85,7 +87,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
* Determines the ranges by sampling the RDD passed in.
*/
-class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+class RangePartitioner[K <% Ordered[K]: ClassTag, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
private val ascending: Boolean = true)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0aafc0a2fc..a0f794edfd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,9 +24,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -51,25 +51,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.LocalSparkCluster
+import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util._
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
-import scala.Some
+import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+ TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -88,7 +82,7 @@ class SparkContext(
val sparkHome: String = null,
val jars: Seq[String] = Nil,
val environment: Map[String, String] = Map(),
- // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc)
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc)
// too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set
// of data-local splits on host
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
@@ -125,7 +119,7 @@ class SparkContext(
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
- // Initalize the Spark UI
+ // Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -151,89 +145,16 @@ class SparkContext(
executorEnvs ++= environment
}
- // Create and start the scheduler
- private[spark] var taskScheduler: TaskScheduler = {
- // Regular expression used for local[N] master format
- val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
- // Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
- // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
- // Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """spark://(.*)""".r
- //Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """(mesos://.*)""".r
-
- master match {
- case "local" =>
- new LocalScheduler(1, 0, this)
-
- case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0, this)
-
- case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt, this)
-
- case SPARK_REGEX(sparkUrl) =>
- val scheduler = new ClusterScheduler(this)
- val masterUrls = sparkUrl.split(",").map("spark://" + _)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- scheduler
-
- case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
- val memoryPerSlaveInt = memoryPerSlave.toInt
- if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
- throw new SparkException(
- "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
- memoryPerSlaveInt, SparkContext.executorMemoryRequested))
- }
-
- val scheduler = new ClusterScheduler(this)
- val localCluster = new LocalSparkCluster(
- numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val masterUrls = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
- localCluster.stop()
- }
- scheduler
-
- case "yarn-standalone" =>
- val scheduler = try {
- val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
- val cons = clazz.getConstructor(classOf[SparkContext])
- cons.newInstance(this).asInstanceOf[ClusterScheduler]
- } catch {
- // TODO: Enumerate the exact reasons why it can fail
- // But irrespective of it, it means we cannot proceed !
- case th: Throwable => {
- throw new SparkException("YARN mode not available ?", th)
- }
- }
- val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
- scheduler.initialize(backend)
- scheduler
-
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
- MesosNativeLibrary.load()
- val scheduler = new ClusterScheduler(this)
- val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
- val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
- } else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
- }
- scheduler.initialize(backend)
- scheduler
- }
+ // Set SPARK_USER for user who is running SparkContext.
+ val sparkUser = Option {
+ Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ }.getOrElse {
+ SparkContext.SPARK_UNKNOWN_USER
}
+ executorEnvs("SPARK_USER") = sparkUser
+
+ // Create and start the scheduler
+ private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName)
taskScheduler.start()
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
@@ -244,7 +165,7 @@ class SparkContext(
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
@@ -254,8 +175,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -269,6 +192,12 @@ class SparkContext(
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ private[spark] def getLocalProperties(): Properties = localProperties.get()
+
+ private[spark] def setLocalProperties(props: Properties) {
+ localProperties.set(props)
+ }
+
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -288,15 +217,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
+ /**
+ * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+ * running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description")
+ * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel")
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ /** Clear the job group id and its description. */
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
+ }
+
// Post init
taskScheduler.postStartHook()
- val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -308,19 +268,19 @@ class SparkContext(
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
- def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
/** Distribute a local Scala collection to form an RDD, with one or more
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
- def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -347,7 +307,7 @@ class SparkContext(
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
- SparkEnv.get.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -373,7 +333,7 @@ class SparkContext(
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -381,17 +341,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
: RDD[(K, V)] = {
hadoopFile(path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]],
minSplits)
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -399,17 +359,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
newAPIHadoopFile(
path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]])
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]])
}
/**
@@ -467,11 +427,11 @@ class SparkContext(
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
- * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
+ * for the appropriate type. In addition, we pass the converter a ClassTag of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
- (implicit km: ClassManifest[K], vm: ClassManifest[V],
+ (implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
@@ -490,7 +450,7 @@ class SparkContext(
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
- def objectFile[T: ClassManifest](
+ def objectFile[T: ClassTag](
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
@@ -499,17 +459,17 @@ class SparkContext(
}
- protected[spark] def checkpointFile[T: ClassManifest](
+ protected[spark] def checkpointFile[T: ClassTag](
path: String
): RDD[T] = {
new CheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
- def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
/** Build the union of a list of RDDs passed as variable-length arguments. */
- def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
+ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
@@ -557,7 +517,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -651,12 +612,11 @@ class SparkContext(
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
- logWarning("null specified as parameter to addJar",
- new SparkException("null specified as parameter to addJar"))
+ logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@@ -665,8 +625,9 @@ class SparkContext(
} else {
val uri = new URI(path)
key = uri.getScheme match {
+ // A JAR file which exists only on the driver node
case null | "file" =>
- if (env.hadoop.isYarnMode()) {
+ if (SparkHadoopUtil.get.isYarnMode()) {
// In order for this to work on yarn the user must specify the --addjars option to
// the client to upload the file into the distributed cache to make it show up in the
// current working directory.
@@ -682,6 +643,9 @@ class SparkContext(
} else {
env.httpFileServer.addJar(new File(uri.getPath))
}
+ // A JAR file which exists locally on every worker node
+ case "local" =>
+ "file:" + uri.getPath
case _ =>
path
}
@@ -748,7 +712,7 @@ class SparkContext(
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -758,11 +722,10 @@ class SparkContext(
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
- result
}
/**
@@ -770,7 +733,7 @@ class SparkContext(
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -785,7 +748,7 @@ class SparkContext(
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
@@ -797,21 +760,21 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
@@ -822,7 +785,7 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
@@ -867,13 +830,19 @@ class SparkContext(
callSite,
allowLocal = false,
resultHandler,
- null)
+ localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
}
/**
- * Cancel all jobs that have been scheduled or are running.
+ * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+ * for more information.
*/
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
+ /** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
@@ -895,9 +864,8 @@ class SparkContext(
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
- val env = SparkEnv.get
val path = new Path(dir)
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -934,7 +902,12 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
- val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
+ private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -958,16 +931,16 @@ object SparkContext {
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
- implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
- implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions[K, V, (K, V)](rdd)
@@ -992,16 +965,16 @@ object SparkContext {
implicit def stringToText(s: String) = new Text(s)
- private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
+ private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
- new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
+ new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
// Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = {
- val wClass = classManifest[W].erasure.asInstanceOf[Class[W]]
+ private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = {
+ val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
@@ -1020,7 +993,7 @@ object SparkContext {
implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
- new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
+ new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
@@ -1052,17 +1025,135 @@ object SparkContext {
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
+
+ // Creates a task scheduler based on a given master URL. Extracted for testing.
+ private
+ def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ // Regular expression for local[N, maxRetries], used in tests with failing tasks
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+ // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
+ // Regular expression for connecting to Spark deploy clusters
+ val SPARK_REGEX = """spark://(.*)""".r
+ // Regular expression for connection to Mesos cluster by mesos:// or zk:// url
+ val MESOS_REGEX = """(mesos|zk)://.*""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
+
+ master match {
+ case "local" =>
+ new LocalScheduler(1, 0, sc)
+
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt, 0, sc)
+
+ case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+ new LocalScheduler(threads.toInt, maxFailures.toInt, sc)
+
+ case SPARK_REGEX(sparkUrl) =>
+ val scheduler = new ClusterScheduler(sc)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
+ }
+
+ val scheduler = new ClusterScheduler(sc)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
+ localCluster.stop()
+ }
+ scheduler
+
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
+ case "yarn-client" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ val backend = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+ cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ scheduler.initialize(backend)
+ scheduler
+
+ case mesosUrl @ MESOS_REGEX(_) =>
+ MesosNativeLibrary.load()
+ val scheduler = new ClusterScheduler(sc)
+ val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
+ val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
+ val backend = if (coarseGrained) {
+ new CoarseMesosSchedulerBackend(scheduler, sc, url, appName)
+ } else {
+ new MesosSchedulerBackend(scheduler, sc, url, appName)
+ }
+ scheduler.initialize(backend)
+ scheduler
+
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new ClusterScheduler(sc)
+ val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
+ scheduler.initialize(backend)
+ scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
+ }
+ }
}
/**
* A class encapsulating how to convert some type T to Writable. It stores both the Writable class
* corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
- * The getter for the writable class takes a ClassManifest[T] in case this is a generic object
+ * The getter for the writable class takes a ClassTag[T] in case this is a generic object
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
private[spark] class WritableConverter[T](
- val writableClass: ClassManifest[T] => Class[_ <: Writable],
+ val writableClass: ClassTag[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 29968c273c..826f5c2d8c 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,18 +20,18 @@ package org.apache.spark
import collection.mutable
import serializer.Serializer
-import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
+import akka.actor._
import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
+import com.google.common.collect.MapMaker
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -58,18 +58,9 @@ class SparkEnv (
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
- val hadoop = {
- val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
- if(yarnMode) {
- try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
- } catch {
- case th: Throwable => throw new SparkException("Unable to load YARN support", th)
- }
- } else {
- new SparkHadoopUtil
- }
- }
+ // A general, soft-reference map for metadata needed during HadoopRDD split computation
+ // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -83,7 +74,8 @@ class SparkEnv (
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
- actorSystem.awaitTermination()
+ // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
+ //actorSystem.awaitTermination()
}
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
@@ -160,17 +152,17 @@ object SparkEnv extends Logging {
val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"))
- def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+ def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = {
if (isDriver) {
logInfo("Registering " + name)
- actorSystem.actorOf(Props(newActor), name = name)
+ Left(actorSystem.actorOf(Props(newActor), name = name))
} else {
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
Utils.checkHost(driverHost, "Expected hostname")
- val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
+ val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
- actorSystem.actorFor(url)
+ Right(actorSystem.actorSelection(url))
}
}
@@ -187,10 +179,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index afa76a4a76..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,6 +36,7 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
+private[apache]
class SparkHadoopWriter(@transient jobConf: JobConf)
extends Logging
with SparkHadoopMapRedUtil
@@ -86,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
getOutputCommitter().setupTask(getTaskContext())
- writer = getOutputFormat().getRecordWriter(
- fs, conf.value, outputName, Reporter.NULL)
+ writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
- if (writer!=null) {
- //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@@ -182,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
}
+private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala
index 19ce8369d9..0bf1e4a5e2 100644
--- a/core/src/main/scala/org/apache/spark/TaskState.scala
+++ b/core/src/main/scala/org/apache/spark/TaskState.scala
@@ -19,8 +19,7 @@ package org.apache.spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
-private[spark] object TaskState
- extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+private[spark] object TaskState extends Enumeration {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 5fd1fab580..da30cf619a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -17,18 +17,23 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.StatCounter
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.storage.StorageLevel
+
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
- override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
+ override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
@@ -42,12 +47,25 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
+
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@@ -81,8 +99,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
+ /**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -158,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index a6518abf45..363667fa86 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -22,6 +22,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -43,13 +44,13 @@ import org.apache.spark.rdd.OrderedRDDFunctions
import org.apache.spark.storage.StorageLevel
-class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
- implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K],
+ implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
- override val classManifest: ClassManifest[(K, V)] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ override val classTag: ClassTag[(K, V)] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
import JavaPairRDD._
@@ -58,13 +59,26 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -95,6 +109,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -114,14 +139,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def first(): (K, V) = rdd.first()
// Pair RDD functions
-
+
/**
- * Generic function to combine the elements for each key using a custom set of aggregation
- * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
- * "combined type" C * Note that V and C can be different -- for example, one might group an
- * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+ * "combined type" C * Note that V and C can be different -- for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
- *
+ *
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
@@ -133,8 +158,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
@@ -171,14 +195,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
@@ -234,7 +258,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -291,15 +315,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
- /**
+ /**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
@@ -390,8 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.mapValues(f))
}
@@ -402,8 +424,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.flatMapValues(fn))
}
@@ -568,6 +589,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
}
/**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
+ class KeyOrdering(val a: K) extends Ordered[K] {
+ override def compare(b: K) = comp.compare(a, b)
+ }
+ implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions))
+ }
+
+ /**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
@@ -579,23 +614,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
}
object JavaPairRDD {
- def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[T]): RDD[(K, JList[T])] =
+ def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[T]): RDD[(K, JList[T])] =
rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
- def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
- Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+ def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd)
+ .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
- Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+ Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1],
JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
(x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
seqAsJavaList(x._2),
seqAsJavaList(x._3)))
- def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+ def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+
+
+ /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
+ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
+ implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ new JavaPairRDD[K, V](rdd.rdd)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index eec58abdd6..037cd1c774 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,12 +17,14 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
-class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
+class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
@@ -41,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -74,6 +84,17 @@ JavaRDDLike[T, JavaRDD[T]] {
rdd.coalesce(numPartitions, shuffle)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
@@ -104,12 +125,13 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
+
+ override def toString = rdd.toString
}
object JavaRDD {
- implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
+ implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7a3568c5ef..f344804b4c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -20,6 +20,7 @@ package org.apache.spark.api.java
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -35,7 +36,7 @@ import org.apache.spark.storage.StorageLevel
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
- implicit val classManifest: ClassManifest[T]
+ implicit val classTag: ClassTag[T]
def rdd: RDD[T]
@@ -71,7 +72,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[R: ClassManifest](
+ def mapPartitionsWithIndex[R: ClassTag](
f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
@@ -87,7 +88,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
@@ -118,7 +119,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -158,18 +159,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
- JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
- other.classManifest)
+ JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, other.classTag)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
@@ -178,10 +177,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
@@ -209,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* a map on the other).
*/
def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
- JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
+ JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}
/**
@@ -224,7 +222,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
- rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
+ rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType())
}
// Actions (launch a job to return a value to the user program)
@@ -247,6 +245,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
+ * Return an array that contains all of the elements in a specific partition of this RDD.
+ */
+ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
+ // This is useful for implementing `take` from other language frontends
+ // like Python where the data is serialized.
+ import scala.collection.JavaConversions._
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ res.map(x => new java.util.ArrayList(x.toSeq)).toArray
+ }
+
+ /**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
@@ -356,7 +365,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
- implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 8869e072bf..acf328aa6a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,6 +21,7 @@ import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
@@ -82,8 +83,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
@@ -94,10 +94,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
@@ -132,16 +130,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
/**Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
}
@@ -153,8 +151,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path, minSplits)(cm)
}
@@ -166,8 +163,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path)(cm)
}
@@ -183,8 +179,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -199,8 +195,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
@@ -212,8 +208,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -224,8 +220,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path,
inputFormatClass, keyClass, valueClass))
}
@@ -240,8 +236,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
kClass: Class[K],
vClass: Class[V],
conf: Configuration): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
}
@@ -254,15 +250,15 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
/** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[T] = first.classManifest
+ implicit val cm: ClassTag[T] = first.classTag
sc.union(rdds)(cm)
}
@@ -270,9 +266,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[(K, V)] = first.classManifest
- implicit val kcm: ClassManifest[K] = first.kManifest
- implicit val vcm: ClassManifest[V] = first.vManifest
+ implicit val cm: ClassTag[(K, V)] = first.classTag
+ implicit val kcm: ClassTag[K] = first.kClassTag
+ implicit val vcm: ClassTag[V] = first.vClassTag
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
@@ -405,8 +401,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
}
protected def checkpointFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
new JavaRDD(sc.checkpointFile(path))
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
index c9cbce5624..2090efd3b9 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -17,7 +17,6 @@
package org.apache.spark.api.java;
-import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 4830067f7a..3e85052cd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
-
- public abstract Iterable<Double> call(T t);
-
- @Override
- public final Iterable<Double> apply(T t) { return call(t); }
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
index db34cd190a..5e9b8c48b8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -29,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
-
- public abstract Double call(T t) throws Exception;
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index 158539a846..bdb01f7670 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -17,12 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- @throws(classOf[Exception])
- def call(x: T) : java.lang.Iterable[R]
-
- def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
+ def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 5ef6a814f5..aae1349c5e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -17,12 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- @throws(classOf[Exception])
- def call(a: A, b:B) : java.lang.Iterable[C]
-
- def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+ def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index b9070cfd83..537439ef53 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -17,9 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -30,10 +29,8 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
- public abstract R call(T t) throws Exception;
-
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
index d4c9154869..a2d1214fb4 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -17,9 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction2;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -29,10 +28,8 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
- public abstract R call(T1 t1, T2 t2) throws Exception;
-
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
index aea08ff81b..fb1deceab5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -15,27 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark.rdd
+package org.apache.spark.api.java.function;
-import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+import scala.runtime.AbstractFunction2;
+import java.io.Serializable;
/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
+ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
*/
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: (TaskContext, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean
- ) extends RDD[U](prev) {
+public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
+ implements Serializable {
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
- override def compute(split: Partition, context: TaskContext) =
- f(context, firstParent[T].iterator(split, context))
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
+ }
}
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index c0e5544b7d..ca485b3cc2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -18,9 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -34,13 +33,11 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index 40480fe8e8..cbe2306026 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -18,9 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -29,17 +28,14 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
- extends WrappedFunction1<T, Tuple2<K, V>>
+public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public abstract Tuple2<K, V> call(T t) throws Exception;
-
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
new file mode 100644
index 0000000000..d314dbdf1d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction3
+
+/**
+ * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
+ extends AbstractFunction3[T1, T2, T3, R] {
+ @throws(classOf[Exception])
+ def call(t1: T1, t2: T2, t3: T3): R
+
+ final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1f8ad688a6..ca42c76928 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,18 +22,17 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
-
-private[spark] class PythonRDD[T: ClassManifest](
+private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
- command: Seq[String],
+ command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
- accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
- broadcastVars, accumulator)
-
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
-
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ dataOut.writeUTF(SparkFiles.getRootDirectory)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- for (f <- pythonIncludes) {
- PythonRDD.writeAsPickle(f, dataOut)
- }
+ pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
- printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
- case -3 =>
+ case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
@@ -143,30 +125,30 @@ private[spark] class PythonRDD[T: ClassManifest](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
- case -2 =>
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
- case -1 =>
+ case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
- // read some accumulator updates; let's do that, breaking when we
- // get a negative length record.
- var len2 = stream.readInt()
- while (len2 >= 0) {
- val update = new Array[Byte](len2)
+ // read some accumulator updates:
+ val numAccumulatorUpdates = stream.readInt()
+ (1 to numAccumulatorUpdates).foreach { _ =>
+ val updateLen = stream.readInt()
+ val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
- len2 = stream.readInt()
+
}
- new Array[Byte](0)
+ Array.empty[Byte]
}
} catch {
case eof: EOFException => {
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
-private[spark] object PythonRDD {
-
- /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
- def stripPickle(arr: Array[Byte]) : Array[Byte] = {
- arr.slice(2, arr.length - 1)
- }
+private object SpecialLengths {
+ val END_OF_DATA_SECTION = -1
+ val PYTHON_EXCEPTION_THROWN = -2
+ val TIMING_DATA = -3
+}
- /**
- * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
- * The data format is a 32-bit integer representing the pickled object's length (in bytes),
- * followed by the pickled data.
- *
- * Pickle module:
- *
- * http://docs.python.org/2/library/pickle.html
- *
- * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
- *
- * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
- * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
- *
- * @param elem the object to write
- * @param dOut a data output stream
- */
- def writeAsPickle(elem: Any, dOut: DataOutputStream) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
- val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
- val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t._1))
- dOut.write(PythonRDD.stripPickle(t._2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new SparkException("Unexpected RDD type")
- }
- }
+private[spark] object PythonRDD {
- def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -265,39 +200,41 @@ private[spark] object PythonRDD {
}
} catch {
case eof: EOFException => {}
- case e => throw e
+ case e: Throwable => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ def writeToStream(elem: Any, dataOut: DataOutputStream) {
+ elem match {
+ case bytes: Array[Byte] =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ case pair: (Array[Byte], Array[Byte]) =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ case str: String =>
+ dataOut.writeUTF(str)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
+ }
+ }
+
+ def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
- writeIteratorToPickleFile(items.asScala, filename)
+ writeToFile(items.asScala, filename)
}
- def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
- writeAsPickle(item, file)
+ writeToStream(item, file)
}
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
- implicit val cm : ClassManifest[T] = rdd.elementClassManifest
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
- }
-}
-
-private object Pickle {
- val PROTO: Byte = 0x80.toByte
- val TWO: Byte = 0x02.toByte
- val BINUNICODE: Byte = 'X'
- val STOP: Byte = '.'
- val TUPLE2: Byte = 0x86.toByte
- val EMPTY_LIST: Byte = ']'
- val MARK: Byte = '('
- val APPENDS: Byte = 'e'
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
@@ -308,7 +245,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 67d45723ba..f291266fcf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -64,7 +64,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
startDaemon()
new Socket(daemonHost, daemonPort)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
}
@@ -198,7 +198,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}.start()
} catch {
- case e => {
+ case e: Throwable => {
stopDaemon()
throw e
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index 5332510e87..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1060 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id)
- with Logging
- with Serializable {
-
- def value = value_
-
- def blockId = BroadcastBlockId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = new AtomicInteger(0)
-
- // Used ONLY by driver to track how many unique blocks have been sent out
- @transient var sentBlocks = new AtomicInteger(0)
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in driver
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks.set(variableInfo.totalBlocks)
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet(totalBlocks)
- hasBlocksBitVector.set(0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int](totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val driverSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- driverSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += driverSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Driver sends everything as 0/null
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = new AtomicInteger(0)
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo]()
-
- stopBroadcast = false
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks.get
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources(newSourceInfo: SourceInfo) {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources(newSourceInfo)
- }
- }
-
- class TalkToGuide(gInfo: SourceInfo)
- extends Thread with Logging {
- override def run() {
-
- // Keep exchaning information until all blocks have been received
- while (hasBlocks.get < totalBlocks) {
- talkOnce
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
- oosGuide.flush()
- oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush()
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Driver " + suitableSources)
-
- addToListOfSources(suitableSources)
-
- oisGuide.close()
- oosGuide.close()
- clientSocketToGuide.close()
- }
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- hasBlocksBitVector = new BitSet(totalBlocks)
- numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide(gInfo)
- ttGuide.setDaemon(true)
- ttGuide.start()
- logInfo("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon(true)
- pcController.start()
- logInfo("PeerChatterController started...")
-
- // FIXME: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks.get < totalBlocks) {
- Thread.sleep(MultiTracker.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo]()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet(totalBlocks)
-
- override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(
- MultiTracker.MaxChatSlots, "Bit Torrent Chatter")
-
- while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate = 0
- listOfSources.synchronized {
- numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
- threadPool.getActiveCount
- }
-
- while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkToRandom
-
- if (peerToTalkTo != null)
- logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
- else
- logDebug("No peer chosen...")
-
- if (peerToTalkTo != null) {
- threadPool.execute(new TalkToPeer(peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep(MultiTracker.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- // Right now picking the one that has the most blocks this peer wants
- // Also picking peer randomly if no one has anything interesting
- private def pickPeerToTalkToRandom: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logDebug("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Select the peer that has the most blocks that this receiver does not
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
- // Always picking randomly
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- return curPeer
- }
-
- // Picking peer with the weight of rare blocks it has
- private def pickPeerToTalkToRarestFirst: SourceInfo = {
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Count the number of copies of each block in the neighborhood
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // A block is considered rare if there are at most 2 copies of that block
- // This CONSTANT could be a function of the neighborhood size
- var rareBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
- rareBlocksIndices += i
- }
- }
-
- // Find peers with rare blocks
- var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
- var totalRareBlocks = 0
-
- peersNotInUse.foreach { eachPeer =>
- var hasRareBlocks = 0
- rareBlocksIndices.foreach { rareBlock =>
- if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
- hasRareBlocks += 1
- }
- }
-
- if (hasRareBlocks > 0) {
- peersWithRareBlocks += ((eachPeer, hasRareBlocks))
- }
- totalRareBlocks += hasRareBlocks
- }
-
- // Select a peer from peersWithRareBlocks based on weight calculated from
- // unique rare blocks
- var selectedPeerToTalkTo: SourceInfo = null
-
- if (peersWithRareBlocks.size > 0) {
- // Sort the peers based on how many rare blocks they have
- peersWithRareBlocks.sortBy(_._2)
-
- var randomNumber = MultiTracker.ranGen.nextDouble
- var tempSum = 0.0
-
- var i = 0
- do {
- tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
- if (tempSum >= randomNumber) {
- selectedPeerToTalkTo = peersWithRareBlocks(i)._1
- }
- i += 1
- } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
- }
-
- if (selectedPeerToTalkTo == null) {
- selectedPeerToTalkTo = pickPeerToTalkToRandom
- }
-
- return selectedPeerToTalkTo
- }
-
- class TalkToPeer(peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run() {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run() {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
-
- logInfo("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources(newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
-
- var keepReceiving = true
-
- while (hasBlocks.get < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other threads know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush()
-
- // CHANGED: Driver might send some other block than the one
- // requested to ensure fast spreading of all blocks.
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set(bcBlock.blockID)
- hasBlocks.getAndIncrement
- }
-
- // Some block(may NOT be blockToAskFor) has arrived.
- // In any case, blockToAskFor is not in request any more
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logError("TalktoPeer had a " + e)
- // FIXME: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
- }
-
- cleanUpConnections()
- }
- }
-
- // Right now it picks a block uniformly that this peer does not have
- private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
- i -= 1
- }
-
- return pickedBlockIndex
- }
- }
-
- // Pick the block that seems to be the rarest across sources
- private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Count the number of copies for each block across all sources
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // Find the minimum
- var minVal = Integer.MAX_VALUE
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
- minVal = numCopiesPerBlock(i)
- }
- }
-
- // Find the blocks with the least copies that this peer does not have
- var minBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
- minBlocksIndices += i
- }
- }
-
- // Now select a random index from minBlocksIndices
- if (minBlocksIndices.size == 0) {
- return -1
- } else {
- // Pick uniformly the i'th index
- var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
- return minBlocksIndices(i)
- }
- }
- }
-
- private def cleanUpConnections() {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
- }
- }
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool("Bit torrent guide multiple requests")
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources(sourceInfo)
- logDebug("Sending selectedSources:" + selectedSources)
- oos.writeObject(selectedSources)
- oos.flush()
-
- // Add this source to the listOfSources
- addToListOfSources(sourceInfo)
- } catch {
- case e: Exception => {
- // Assuming exception caused by receiver failure: remove
- if (listOfSources != null) {
- listOfSources.synchronized { listOfSources -= sourceInfo }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo]()
-
- // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
- // then add skipSourceInfo to setOfCompletedSources. Return blank.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = MultiTracker.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet(listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
- do {
- i = MultiTracker.ranGen.nextInt(listOfSources.size)
- } while (alreadyPicked.get(i))
-
- var peerIter = listOfSources.iterator
- var curPeer = peerIter.next
-
- // Set the BitSet before i is decremented
- alreadyPicked.set(i)
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
-
- selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- // Server at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(
- MultiTracker.MaxChatSlots, "Bit torrent serve multiple requests")
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ServeSingleRequest is running")
-
- override def run() {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush()
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- addToListOfSources(rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = MultiTracker.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- var blockToSend = ois.readObject.asInstanceOf[Int]
-
- // If it is driver AND at least one copy of each block has not been
- // sent out already, MODIFY blockToSend
- if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
- blockToSend = sentBlocks.getAndIncrement
- }
-
- // Send the block
- sendBlock(blockToSend)
- rxSourceInfo.hasBlocksBitVector.set(blockToSend)
-
- numBlocksToSend -= 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources(rxSourceInfo)
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= MultiTracker.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendBlock(blockToSend: Int) {
- try {
- oos.writeObject(arrayOfBlocks(blockToSend))
- oos.flush()
- } catch {
- case e: Exception => logError("sendBlock had a " + e)
- }
- logDebug("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-private[spark] class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new BitTorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 609464e38d..47db720416 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
import java.net.URL
+import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging {
private val files = new TimeStampedHashSet[String]
private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
+ private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt
+
private lazy val compressionCodec = CompressionCodec.createCodec()
def initialize(isDriver: Boolean) {
@@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging {
def read[T](id: Long): T = {
val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
+ val httpConnection = new URL(url).openConnection()
+ httpConnection.setReadTimeout(httpReadTimeout)
+ val inputStream = httpConnection.getInputStream()
if (compress) {
- compressionCodec.compressedInputStream(new URL(url).openStream())
+ compressionCodec.compressedInputStream(inputStream)
} else {
- new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+ new FastBufferedInputStream(inputStream, bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
deleted file mode 100644
index 82ed64f190..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
+++ /dev/null
@@ -1,410 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.Random
-
-import scala.collection.mutable.Map
-
-import org.apache.spark._
-import org.apache.spark.util.Utils
-
-private object MultiTracker
-extends Logging {
-
- // Tracker Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
-
- // Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[Long, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var _isDriver = false
-
- private var stopBroadcast = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(__isDriver: Boolean) {
- synchronized {
- if (!initialized) {
- _isDriver = __isDriver
-
- if (isDriver) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
-
- // Set DriverHostAddress to the driver's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
- }
-
- initialized = true
- }
- }
- }
-
- def stop() {
- stopBroadcast = true
- }
-
- // Load common parameters
- private var DriverHostAddress_ = System.getProperty(
- "spark.MultiTracker.DriverHostAddress", "")
- private var DriverTrackerPort_ = System.getProperty(
- "spark.broadcast.driverTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty(
- "spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxChatSlots_ = System.getProperty(
- "spark.broadcast.maxChatSlots", "4").toInt
- private var MaxChatTime_ = System.getProperty(
- "spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty(
- "spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
-
- def isDriver = _isDriver
-
- // Common config params
- def DriverHostAddress = DriverHostAddress_
- def DriverTrackerPort = DriverTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxChatSlots = MaxChatSlots_
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool("Track multiple values")
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(DriverTrackerPort)
- logInfo("TrackMultipleValues started at " + serverSocket)
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- if (stopBroadcast) {
- logInfo("Stopping TrackMultipleValues...")
- }
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (id -> gInfo)
- }
-
- logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
- }
-
- logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- var gInfo =
- if (valueToGuideMap.contains(id)) valueToGuideMap(id)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logError("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- def getGuideInfo(variableLong: Long): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send Long and receive GuideInfo
- oosTracker.writeObject(variableLong)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => logError("getGuideInfo had a " + e)
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
- def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- // Helper method to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / BlockSize)
- if (byteArray.length % BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, BlockSize)) {
- val thisBlockSize = math.min(BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper method to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
-extends Serializable
-
-private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
-extends Serializable {
- @transient var hasBlocks = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
deleted file mode 100644
index baa1fd6da4..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.util.BitSet
-
-import org.apache.spark._
-
-/**
- * Used to keep and pass around information of peers involved in a broadcast
- */
-private[spark] case class SourceInfo (hostAddress: String,
- listenPort: Int,
- totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
- def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-/**
- * Helper Object of SourceInfo for its constants
- */
-private[spark] object SourceInfo {
- // Broadcast has not started yet! Should never happen.
- val TxNotStartedRetry = -1
- // Broadcast has already finished. Try default mechanism.
- val TxOverGoToDefault = -3
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
deleted file mode 100644
index 51af80a35e..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ /dev/null
@@ -1,601 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-
-import scala.collection.mutable.{ListBuffer, Set}
-
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
-
- def value = value_
-
- def blockId = BroadcastBlockId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- listOfSources += masterSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- var clientSocketToDriver: Socket = null
- var oosDriver: ObjectOutputStream = null
- var oisDriver: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- // Connect to Driver and send this worker's Information
- clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
- oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
- oosDriver.flush()
- oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
-
- logDebug("Connected to Driver's guiding object")
-
- // Send local source information
- oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
- oosDriver.flush()
-
- // Receive source information from Driver
- var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = sourceInfo.totalBytes
-
- logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Driver will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Driver
- oosDriver.writeObject(sourceInfo)
-
- if (oisDriver != null) {
- oisDriver.close()
- }
- if (oosDriver != null) {
- oosDriver.close()
- }
- if (clientSocketToDriver != null) {
- clientSocketToDriver.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- /**
- * Tries to receive broadcast from the source and returns Boolean status.
- * This might be called multiple times to retry a defined number of times.
- */
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logDebug("Inside receiveSingleTransmission")
- logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
-
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
- }
- } catch {
- case e: Exception => logError("receiveSingleTransmission had a " + e)
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast guide multiple requests")
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close() the socket here; else, the thread will close() it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- var listIter = listOfSources.iterator
- while (listIter.hasNext) {
- var sourceInfo = listIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal
- gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
- // Add this new (if it can finish) source to the list of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
- listOfSources += thisWorkerInfo
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in listOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(listOfSources.contains(selectedSourceInfo))
-
- // Remove first
- // (Currently removing a source based on just one failure notification!)
- listOfSources = listOfSources - selectedSourceInfo
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
- }
- } catch {
- case e: Exception => {
- // Remove failed worker from listOfSources and update leecherCount of
- // corresponding source worker
- listOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- listOfSources = listOfSources - selectedSourceInfo
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
-
- // Remove thisWorkerInfo
- if (listOfSources != null) {
- listOfSources = listOfSources - thisWorkerInfo
- }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Assuming the caller to have a synchronized block on listOfSources
- // Select one with the most leechers. This will level-wise fill the tree
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- var maxLeechers = -1
- var selectedSource: SourceInfo = null
-
- listOfSources.foreach { source =>
- if ((source.hostAddress != skipSourceInfo.hostAddress ||
- source.listenPort != skipSourceInfo.listenPort) &&
- source.currentLeechers < MultiTracker.MaxDegree &&
- source.currentLeechers > maxLeechers) {
- selectedSource = source
- maxLeechers = source.currentLeechers
- }
- }
-
- // Update leecher count
- selectedSource.currentLeechers += 1
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-
- var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast serve multiple requests")
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => { }
- }
-
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- // If not a valid range, stop broadcast
- if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- sendObject
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
- // Wait till receiving the SourceInfo from Driver
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized { hasBlocksLock.wait() }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => logError("sendObject had a " + e)
- }
- logDebug("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-private[spark] class TreeBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TreeBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
index fcfea96ad6..37dfa7fec0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy
-private[spark] object ExecutorState
- extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
+private[spark] object ExecutorState extends Enumeration {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 668032a3a2..0aa8852649 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -1,19 +1,19 @@
/*
*
- * * Licensed to the Apache Software Foundation (ASF) under one or more
- * * contributor license agreements. See the NOTICE file distributed with
- * * this work for additional information regarding copyright ownership.
- * * The ASF licenses this file to You under the Apache License, Version 2.0
- * * (the "License"); you may not use this file except in compliance with
- * * the License. You may obtain a copy of the License at
- * *
- * * http://www.apache.org/licenses/LICENSE-2.0
- * *
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS,
- * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * * See the License for the specific language governing permissions and
- * * limitations under the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 308a2bfa22..59d12a3e6f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -17,12 +17,12 @@
package org.apache.spark.deploy
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor.ActorSystem
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
-import org.apache.spark.util.{Utils, AkkaUtils}
-import org.apache.spark.{Logging}
+import org.apache.spark.util.Utils
+import org.apache.spark.Logging
import scala.collection.mutable.ArrayBuffer
@@ -34,11 +34,11 @@ import scala.collection.mutable.ArrayBuffer
*/
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
-
+
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
-
+
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
@@ -61,10 +61,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
+ // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
+ // This is unfortunate, but for now we just comment it out.
workerActorSystems.foreach(_.shutdown())
- workerActorSystems.foreach(_.awaitTermination())
-
+ //workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
- masterActorSystems.foreach(_.awaitTermination())
+ //masterActorSystems.foreach(_.awaitTermination())
+ masterActorSystems.clear()
+ workerActorSystems.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 993ba6bd3d..fc1537f796 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,28 +17,70 @@
package org.apache.spark.deploy
-import com.google.common.collect.MapMaker
+import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark.{SparkContext, SparkException}
/**
- * Contains util methods to interact with Hadoop from spark.
+ * Contains util methods to interact with Hadoop from Spark.
*/
+private[spark]
class SparkHadoopUtil {
- // A general, soft-reference map for metadata needed during HadoopRDD split computation
- // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
- private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
+ val conf = newConfiguration()
+ UserGroupInformation.setConfiguration(conf)
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
- // subsystems
+ def runAsUser(user: String)(func: () => Unit) {
+ // if we are already running as the user intended there is no reason to do the doAs. It
+ // will actually break secure HDFS access as it doesn't fill in the credentials. Also if
+ // the user is UNKNOWN then we shouldn't be creating a remote unknown user
+ // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
+ // in SparkContext.
+ val currentUser = Option(System.getProperty("user.name")).
+ getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+ if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
+ } else {
+ func()
+ }
+ }
+
+ /**
+ * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
+ * subsystems.
+ */
def newConfiguration(): Configuration = new Configuration()
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
+ /**
+ * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ * cluster.
+ */
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
+}
+
+object SparkHadoopUtil {
+ private val hadoop = {
+ val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if (yarnMode) {
+ try {
+ Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ } catch {
+ case th: Throwable => throw new SparkException("Unable to load YARN support", th)
+ }
+ } else {
+ new SparkHadoopUtil
+ }
+ }
+ def get: SparkHadoopUtil = {
+ hadoop
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 77422f61ec..953755e40d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -19,20 +19,18 @@ package org.apache.spark.deploy.client
import java.util.concurrent.TimeoutException
+import scala.concurrent.duration._
+import scala.concurrent.Await
+
import akka.actor._
-import akka.actor.Terminated
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
-import akka.remote.RemoteClientDisconnected
-import akka.remote.RemoteClientLifeCycleEvent
-import akka.remote.RemoteClientShutdown
-import akka.dispatch.Await
-
-import org.apache.spark.Logging
+import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
+
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
+import org.apache.spark.util.AkkaUtils
/**
@@ -51,18 +49,19 @@ private[spark] class Client(
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
+ var masterAddress: Address = null
var actor: ActorRef = null
var appId: String = null
var registered = false
var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
- var master: ActorRef = null
- var masterAddress: Address = null
+ var master: ActorSelection = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
try {
registerWithMaster()
} catch {
@@ -76,7 +75,7 @@ private[spark] class Client(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterApplication(appDescription)
}
}
@@ -84,6 +83,7 @@ private[spark] class Client(
def registerWithMaster() {
tryRegisterAllMasters()
+ import context.dispatcher
var retries = 0
lazy val retryTimer: Cancellable =
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
@@ -102,10 +102,13 @@ private[spark] class Client(
def changeMaster(url: String) {
activeMasterUrl = url
- master = context.actorFor(Master.toAkkaUrl(url))
- masterAddress = master.path.address
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(host, port) =>
+ Address("akka.tcp", Master.systemName, host, port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
}
override def receive = {
@@ -135,21 +138,12 @@ private[spark] class Client(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl)
alreadyDisconnected = false
sender ! MasterChangeAcknowledged(appId)
- case Terminated(actor_) if actor_ == master =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
-
- case RemoteClientShutdown(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
+ case DisassociatedEvent(_, address, _) if address == masterAddress =>
+ logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
case StopClient =>
@@ -184,7 +178,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
- val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index fedf879eff..67e6c5d66a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object ApplicationState
- extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") {
+private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index c0849ef324..043945a211 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -65,7 +65,7 @@ private[spark] class FileSystemPersistenceEngine(
(apps, workers)
}
- private def serializeIntoFile(file: File, value: Serializable) {
+ private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
@@ -77,13 +77,13 @@ private[spark] class FileSystemPersistenceEngine(
out.close()
}
- def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
dis.readFully(fileData)
dis.close()
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index cd916672ac..eebd0794b8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,19 +17,17 @@
package org.apache.spark.deploy.master
-import java.util.Date
import java.text.SimpleDateFormat
+import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.actor.Terminated
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.SerializationExtension
-import akka.util.duration._
-import akka.util.{Duration, Timeout}
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
@@ -40,6 +38,8 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+ import context.dispatcher
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
@@ -61,8 +61,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val waitingApps = new ArrayBuffer[ApplicationInfo]
val completedApps = new ArrayBuffer[ApplicationInfo]
- var firstApp: Option[ApplicationInfo] = None
-
Utils.checkHost(host, "Expected hostname")
val masterMetricsSystem = MetricsSystem.createMetricsSystem("master")
@@ -93,7 +91,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
@@ -113,13 +111,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
new BlackHolePersistenceEngine()
}
- leaderElectionAgent = context.actorOf(Props(
- RECOVERY_MODE match {
+ leaderElectionAgent = RECOVERY_MODE match {
case "ZOOKEEPER" =>
- new ZooKeeperLeaderElectionAgent(self, masterUrl)
+ context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
case _ =>
- new MonarchyLeaderAgent(self)
- }))
+ context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
+ }
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -142,9 +139,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
RecoveryState.ALIVE
else
RecoveryState.RECOVERING
-
logInfo("I have been elected leader! New state: " + state)
-
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedWorkers)
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
@@ -156,7 +151,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
System.exit(0)
}
- case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
+ case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
@@ -164,9 +159,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
} else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
+ sender, workerWebUiPort, publicAddress)
registerWorker(worker)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addWorker(worker)
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
@@ -181,7 +176,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val app = createApplication(description, sender)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addApplication(app)
sender ! RegisteredApplication(app.id, masterUrl)
schedule()
@@ -257,23 +251,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (canCompleteRecovery) { completeRecovery() }
}
- case Terminated(actor) => {
- // The disconnected actor could've been either a worker or an app; remove whichever of
- // those we have an entry for in the corresponding actor hashmap
- actorToWorker.get(actor).foreach(removeWorker)
- actorToApp.get(actor).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientDisconnected(transport, address) => {
- // The disconnected client could've been either a worker or an app; remove whichever it was
- addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientShutdown(transport, address) => {
+ case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
+ logInfo(s"$address got disassociated, removing it.")
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
@@ -459,14 +439,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
idToApp(app.id) = app
actorToApp(app.driver) = app
addressToApp(appAddress) = app
- if (firstApp == None) {
- firstApp = Some(app)
- }
- // TODO: What is firstApp?? Can we remove it?
- val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
- if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
- logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
- }
waitingApps += app
}
@@ -530,9 +502,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
private[spark] object Master {
- private val systemName = "sparkMaster"
+ val systemName = "sparkMaster"
private val actorName = "Master"
- private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+ val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
@@ -540,11 +512,11 @@ private[spark] object Master {
actorSystem.awaitTermination()
}
- /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
- "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
@@ -552,12 +524,10 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
- val timeoutDuration = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
- implicit val timeout = Timeout(timeoutDuration)
- val respFuture = actor ? RequestWebUIPort // ask pattern
- val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName)
+ val timeout = AkkaUtils.askTimeout
+ val respFuture = actor.ask(RequestWebUIPort)(timeout)
+ val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
(actorSystem, boundPort, resp.webUIBoundPort)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
index b91be821f0..256a5a7c28 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object RecoveryState
- extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {
-
+private[spark] object RecoveryState extends Enumeration {
type MasterState = Value
val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
index 81e15c534f..6cc7fd2ff4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala
@@ -18,12 +18,12 @@
package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
-import scala.concurrent.ops._
-import org.apache.spark.Logging
import org.apache.zookeeper._
-import org.apache.zookeeper.data.Stat
import org.apache.zookeeper.Watcher.Event.KeeperState
+import org.apache.zookeeper.data.Stat
+
+import org.apache.spark.Logging
/**
* Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
@@ -33,7 +33,7 @@ import org.apache.zookeeper.Watcher.Event.KeeperState
* informed via zkDown().
*
* Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
- * times or a semantic exception is thrown (e.g.., "node already exists").
+ * times or a semantic exception is thrown (e.g., "node already exists").
*/
private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
@@ -103,6 +103,7 @@ private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) ext
connectToZooKeeper()
case KeeperState.Disconnected =>
logWarning("ZooKeeper disconnected, will retry...")
+ case s => // Do nothing
}
}
}
@@ -179,7 +180,7 @@ private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) ext
} catch {
case e: KeeperException.NoNodeException => throw e
case e: KeeperException.NodeExistsException => throw e
- case e if n > 0 =>
+ case e: Exception if n > 0 =>
logError("ZooKeeper exception, " + n + " more retries...", e)
Thread.sleep(RETRY_WAIT_MILLIS)
retry(fn, n-1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index c8d34f25e2..0b36ef6005 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object WorkerState
- extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") {
-
+private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 7809013e83..7d535b08de 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -21,8 +21,8 @@ import akka.actor.ActorRef
import org.apache.zookeeper._
import org.apache.zookeeper.Watcher.Event.EventType
-import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.Logging
+import org.apache.spark.deploy.master.MasterMessages._
private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
@@ -105,7 +105,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, mas
// We found a different master file pointing to this process.
// This can happen in the following two cases:
// (1) The master process was restarted on the same node.
- // (2) The ZK server died between creating the node and returning the name of the node.
+ // (2) The ZK server died between creating the file and returning the name of the file.
// For this case, we will end up creating a second file, and MUST explicitly delete the
// first one, since our ZK session is still open.
// Note that this deletion will cause a NodeDeleted event to be fired so we check again for
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index a0233a7271..825344b3bb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -70,15 +70,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization)
(apps, workers)
}
- private def serializeIntoFile(path: String, value: Serializable) {
+ private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create(path, serialized, CreateMode.PERSISTENT)
}
- def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
val fileData = zk.getData("/spark/master_status/" + filename)
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f4e574d15d..dbb0cb90f5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -17,31 +17,28 @@
package org.apache.spark.deploy.master.ui
+import scala.concurrent.Await
import scala.xml.Node
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.duration._
-
import javax.servlet.http.HttpServletRequest
-
import net.liftweb.json.JsonAST.JValue
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.JsonProtocol
+import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.deploy.master.ExecutorInfo
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
private[spark] class ApplicationPage(parent: MasterWebUI) {
val master = parent.masterActorRef
- implicit val timeout = parent.timeout
+ val timeout = parent.timeout
/** Executor details for a particular application */
def renderJson(request: HttpServletRequest): JValue = {
val appId = request.getParameter("appId")
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
@@ -52,7 +49,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) {
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
index d7a57229b0..4ef762892c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -17,37 +17,33 @@
package org.apache.spark.deploy.master.ui
-import javax.servlet.http.HttpServletRequest
-
+import scala.concurrent.Await
import scala.xml.Node
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.duration._
-
+import javax.servlet.http.HttpServletRequest
import net.liftweb.json.JsonAST.JValue
-import org.apache.spark.deploy.DeployWebUI
+import org.apache.spark.deploy.{DeployWebUI, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.JsonProtocol
import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
private[spark] class IndexPage(parent: MasterWebUI) {
val master = parent.masterActorRef
- implicit val timeout = parent.timeout
+ val timeout = parent.timeout
def renderJson(request: HttpServletRequest): JValue = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
JsonProtocol.writeMasterState(state)
}
/** Index view listing applications and executors */
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
- val state = Await.result(stateFuture, 30 seconds)
+ val state = Await.result(stateFuture, timeout)
val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index f4df729e87..9ab594b682 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,25 +17,21 @@
package org.apache.spark.deploy.master.ui
-import akka.util.Duration
-
import javax.servlet.http.HttpServletRequest
-
import org.eclipse.jetty.server.{Handler, Server}
-import org.apache.spark.{Logging}
+import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
import org.apache.spark.ui.JettyUtils
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Web UI server for the standalone master.
*/
private[spark]
class MasterWebUI(val master: Master, requestedPort: Int) extends Logging {
- implicit val timeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout
val host = Utils.localHostName()
val port = requestedPort
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 8fabc95665..fff9cb60c7 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -104,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
- command.arguments.map(substituteVariables)
+ (command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 216d9d44ac..87531b6719 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -17,23 +17,31 @@
package org.apache.spark.deploy.worker
+import java.io.File
import java.text.SimpleDateFormat
import java.util.Date
-import java.io.File
import scala.collection.mutable.HashMap
+import scala.concurrent.duration._
import akka.actor._
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
-import akka.util.duration._
+import akka.remote.{ DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.deploy.DeployMessages.WorkerStateResponse
+import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
+import org.apache.spark.deploy.DeployMessages.KillExecutor
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.deploy.DeployMessages.Heartbeat
+import org.apache.spark.deploy.DeployMessages.RegisteredWorker
+import org.apache.spark.deploy.DeployMessages.LaunchExecutor
+import org.apache.spark.deploy.DeployMessages.RegisterWorker
/**
* @param masterUrls Each url should look like spark://host:port.
@@ -47,6 +55,7 @@ private[spark] class Worker(
masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
+ import context.dispatcher
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
@@ -63,7 +72,8 @@ private[spark] class Worker(
var masterIndex = 0
val masterLock: Object = new Object()
- var master: ActorRef = null
+ var master: ActorSelection = null
+ var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
@volatile var registered = false
@@ -114,7 +124,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
-
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
registerWithMaster()
@@ -126,9 +136,13 @@ private[spark] class Worker(
masterLock.synchronized {
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
- master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(_host, _port) =>
+ Address("akka.tcp", Master.systemName, _host, _port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
connected = true
}
}
@@ -136,7 +150,7 @@ private[spark] class Worker(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
publicAddress)
}
@@ -175,7 +189,6 @@ private[spark] class Worker(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl, masterWebUiUrl)
val execs = executors.values.
@@ -234,13 +247,8 @@ private[spark] class Worker(
}
}
- case Terminated(actor_) if actor_ == master =>
- masterDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == master.path.address =>
- masterDisconnected()
-
- case RemoteClientShutdown(transport, address) if address == master.path.address =>
+ case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
+ logInfo(s"$x Disassociated !")
masterDisconnected()
case RequestWorkerState => {
@@ -280,8 +288,8 @@ private[spark] object Worker {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
- masterUrls, workDir)), name = "Worker")
+ actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
+ masterUrls, workDir), name = "Worker")
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index d2d3617498..0d59048313 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -21,9 +21,10 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import akka.dispatch.Await
+import scala.concurrent.duration._
+import scala.concurrent.Await
+
import akka.pattern.ask
-import akka.util.duration._
import net.liftweb.json.JsonAST.JValue
@@ -41,13 +42,13 @@ private[spark] class IndexPage(parent: WorkerWebUI) {
def renderJson(request: HttpServletRequest): JValue = {
val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, 30 seconds)
+ val workerState = Await.result(stateFuture, timeout)
JsonProtocol.writeWorkerState(workerState)
}
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse]
- val workerState = Await.result(stateFuture, 30 seconds)
+ val workerState = Await.result(stateFuture, timeout)
val executorHeaders = Seq("ExecutorID", "Cores", "Memory", "Job Details", "Logs")
val runningExecutorTable =
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 800f1cafcc..40d6bdb3fd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -17,20 +17,16 @@
package org.apache.spark.deploy.worker.ui
-import akka.util.{Duration, Timeout}
-
-import java.io.{FileInputStream, File}
+import java.io.File
import javax.servlet.http.HttpServletRequest
-
import org.eclipse.jetty.server.{Handler, Server}
+import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
-import org.apache.spark.{Logging}
-import org.apache.spark.ui.JettyUtils
+import org.apache.spark.ui.{JettyUtils, UIUtils}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.UIUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* Web UI server for the standalone worker.
@@ -38,8 +34,7 @@ import org.apache.spark.util.Utils
private[spark]
class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None)
extends Logging {
- implicit val timeout = Timeout(
- Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
+ val timeout = AkkaUtils.askTimeout
val host = Utils.localHostName()
val port = requestedPort.getOrElse(
System.getProperty("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 52b1c492b2..debbdd4c44 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,15 +19,14 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
-import akka.actor.{ActorRef, Actor, Props, Terminated}
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
+import akka.actor._
+import akka.remote._
-import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.Logging
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
@@ -40,14 +39,13 @@ private[spark] class CoarseGrainedExecutorBackend(
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
- var driver: ActorRef = null
+ var driver: ActorSelection = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorFor(driverUrl)
+ driver = context.actorSelection(driverUrl)
driver ! RegisterExecutor(executorId, hostPort, cores)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
override def receive = {
@@ -77,9 +75,14 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.killTask(taskId)
}
- case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
- logError("Driver terminated or disconnected! Shutting down.")
+ case x: DisassociatedEvent =>
+ logError(s"Driver $x disassociated! Shutting down.")
System.exit(1)
+
+ case StopExecutor =>
+ logInfo("Driver commanded a shutdown")
+ context.stop(self)
+ context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@@ -94,19 +97,20 @@ private[spark] object CoarseGrainedExecutorBackend {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
+ indestructible = true)
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
- val actor = actorSystem.actorOf(
- Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+ actorSystem.actorOf(
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length < 4) {
- //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
+ //the reason we allow the last appid argument is to make it easy to kill rogue executors
System.err.println(
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
"[<appid>]")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 032eb04f43..0f19d7a96b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -25,8 +25,9 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
-import org.apache.spark.scheduler._
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
@@ -74,30 +75,33 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
- // Make any thread terminations due to uncaught exceptions kill the entire
- // executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(
- new Thread.UncaughtExceptionHandler {
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ if (!isLocal) {
+ // Setup an uncaught exception handler for non-local mode.
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
- }
- )
+ )
+ }
val executorSource = new ExecutorSource(this, executorId)
@@ -117,7 +121,7 @@ private[spark] class Executor(
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
- env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
}
// Start worker thread pool
@@ -126,6 +130,8 @@ private[spark] class Executor(
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
+ val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
+
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
val tr = new TaskRunner(context, taskId, serializedTask)
runningTasks.put(taskId, tr)
@@ -173,7 +179,7 @@ private[spark] class Executor(
}
}
- override def run() {
+ override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
@@ -216,18 +222,22 @@ private[spark] class Executor(
return
}
+ val resultSer = SparkEnv.get.serializer.newInstance()
+ val beforeSerialization = System.currentTimeMillis()
+ val valueBytes = resultSer.serialize(value)
+ val afterSerialization = System.currentTimeMillis()
+
for (m <- task.metrics) {
m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
m.jvmGCTime = gcTime - startGCTime
+ m.resultSerializationTime = (afterSerialization - beforeSerialization).toInt
}
- // TODO I'd also like to track the time it takes to serialize the task results, but that is
- // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
- // custom serialized format, we could just change the relevants bytes in the byte buffer
+
val accumUpdates = Accumulators.values
- val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+ val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
index 34ed9c8f73..97176e4f5b 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala
@@ -20,8 +20,6 @@ package org.apache.spark.executor
import com.codahale.metrics.{Gauge, MetricRegistry}
import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.hdfs.DistributedFileSystem
-import org.apache.hadoop.fs.LocalFileSystem
import scala.collection.JavaConversions._
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index f311141148..bb1471d9ee 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -44,6 +44,11 @@ class TaskMetrics extends Serializable {
var jvmGCTime: Long = _
/**
+ * Amount of time spent serializing the task result
+ */
+ var resultSerializationTime: Long = _
+
+ /**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
*/
var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
@@ -61,45 +66,53 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
- * Time when shuffle finishs
+ * Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
/**
- * Total number of blocks fetched in a shuffle (remote or local)
+ * Number of blocks fetched in this shuffle by this task (remote or local)
*/
var totalBlocksFetched: Int = _
/**
- * Number of remote blocks fetched in a shuffle
+ * Number of remote blocks fetched in this shuffle by this task
*/
var remoteBlocksFetched: Int = _
/**
- * Local blocks fetched in a shuffle
+ * Number of local blocks fetched in this shuffle by this task
*/
var localBlocksFetched: Int = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch data
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
*/
var fetchWaitTime: Long = _
/**
- * The total amount of time for all the shuffle fetches. This adds up time from overlapping
- * shuffles, so can be longer than task time
+ * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+ * input blocks. Since block fetches are both pipelined and parallelized, this can
+ * exceed fetchWaitTime and executorRunTime.
*/
var remoteFetchTime: Long = _
/**
- * Total number of remote bytes read from a shuffle
+ * Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
- * Number of bytes written for a shuffle
+ * Number of bytes written for the shuffle by this task
*/
var shuffleBytesWritten: Long = _
+
+ /**
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
+ */
+ var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
new file mode 100644
index 0000000000..cdcfec8ca7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+import java.net.InetSocketAddress
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.graphite.{GraphiteReporter, Graphite}
+
+import org.apache.spark.metrics.MetricsSystem
+
+class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val GRAPHITE_DEFAULT_PERIOD = 10
+ val GRAPHITE_DEFAULT_UNIT = "SECONDS"
+ val GRAPHITE_DEFAULT_PREFIX = ""
+
+ val GRAPHITE_KEY_HOST = "host"
+ val GRAPHITE_KEY_PORT = "port"
+ val GRAPHITE_KEY_PERIOD = "period"
+ val GRAPHITE_KEY_UNIT = "unit"
+ val GRAPHITE_KEY_PREFIX = "prefix"
+
+ def propertyToOption(prop: String) = Option(property.getProperty(prop))
+
+ if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
+ throw new Exception("Graphite sink requires 'host' property.")
+ }
+
+ if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) {
+ throw new Exception("Graphite sink requires 'port' property.")
+ }
+
+ val host = propertyToOption(GRAPHITE_KEY_HOST).get
+ val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt
+
+ val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match {
+ case Some(s) => s.toInt
+ case None => GRAPHITE_DEFAULT_PERIOD
+ }
+
+ val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
+ }
+
+ val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX)
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+
+ val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .prefixedWith(prefix)
+ .build(graphite)
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 9c2fee4023..703bc6a9ca 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -31,11 +31,11 @@ import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.ArrayBuffer
-import akka.dispatch.{Await, Promise, ExecutionContext, Future}
-import akka.util.Duration
-import akka.util.duration._
-import org.apache.spark.util.Utils
+import scala.concurrent.{Await, Promise, ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
+import org.apache.spark.util.Utils
private[spark] class ConnectionManager(port: Int) extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
index 8d9ad9604d..4f5742d29b 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -25,8 +25,8 @@ import scala.io.Source
import java.nio.ByteBuffer
import java.net.InetAddress
-import akka.dispatch.Await
-import akka.util.duration._
+import scala.concurrent.Await
+import scala.concurrent.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 481ff8c3e0..b1e1576dad 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging {
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
- logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)")
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 1586dff254..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return file.getAbsolutePath
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index f132e2b735..70a5a8caff 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+package org.apache
+
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index faaf837be0..d1c74a5063 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
+import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
@@ -28,7 +29,7 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
+class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
* Returns a future for counting the number of elements in the RDD.
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 44ea573a7c..424354ae16 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
import org.apache.spark.storage.{BlockId, BlockManager}
@@ -25,7 +27,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P
}
private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 9b0c882481..87b950ba43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
import org.apache.spark._
@@ -43,7 +44,7 @@ class CartesianPartition(
}
private[spark]
-class CartesianRDD[T: ClassManifest, U:ClassManifest](
+class CartesianRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1 : RDD[T],
var rdd2 : RDD[U])
@@ -70,7 +71,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
override def compute(split: Partition, context: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
- y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override def getDependencies: Seq[Dependency[_]] = List(
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index ccaaecb85b..a712ef1c27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,14 +17,13 @@
package org.apache.spark.rdd
-import org.apache.spark._
-import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.io.{NullWritable, BytesWritable}
-import org.apache.hadoop.util.ReflectionUtils
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
-import java.io.{File, IOException, EOFException}
-import java.text.NumberFormat
+import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -32,7 +31,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
-class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
+class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@@ -83,7 +82,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -122,7 +121,7 @@ private[spark] object CheckpointRDD extends Logging {
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
@@ -145,7 +144,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index c5de6362a9..98da35763b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -22,6 +22,7 @@ import java.io.{ObjectOutputStream, IOException}
import scala.collection.mutable
import scala.Some
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -68,7 +69,7 @@ case class CoalescedRDDPartition(
* @param maxPartitions number of desired partitions in the coalesced RDD
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
-class CoalescedRDD[T: ClassManifest](
+class CoalescedRDD[T: ClassTag](
@transient var prev: RDD[T],
maxPartitions: Int,
balanceSlack: Double = 0.10)
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..688c310ee9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,129 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(Double.NegativeInfinity,
+ Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
index c8900d1a93..a84e5f9fd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -17,13 +17,14 @@
package org.apache.spark.rdd
-import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Partition, SparkContext, TaskContext}
/**
* An RDD that is empty, i.e. has no element in it.
*/
-class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
override def getPartitions: Array[Partition] = Array.empty
diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
index 5312dc0b59..e74c83b90b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class FilteredRDD[T: ClassManifest](
+private[spark] class FilteredRDD[T: ClassTag](
prev: RDD[T],
f: T => Boolean)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
index cbdf6d84c0..4d1878fc14 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
@@ -18,10 +18,11 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
private[spark]
-class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+class FlatMappedRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
index 829545d7b0..1a694475f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T])
extends RDD[Array[T]](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index fad042c7ae..53f77a38f5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
@@ -51,7 +52,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
* sources in HBase, or S3).
*
* @param sc The SparkContext to associate the RDD with.
- * @param broadCastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
* variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
* Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
* @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
@@ -131,6 +132,8 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
+ // add the credentials here as this can be called before SparkContext initialized
+ SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
@@ -198,10 +201,10 @@ private[spark] object HadoopRDD {
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
*/
- def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
+ def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
- def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
+ def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
def putCachedMetadata(key: String, value: Any) =
- SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
+ SparkEnv.get.hadoopJobMetadata.put(key, value)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index aca0146884..8df8718f3b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.sql.{Connection, ResultSet}
+import scala.reflect.ClassTag
+
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.util.NextIterator
@@ -45,7 +47,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
* This should only call getInt, getString, etc; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
*/
-class JdbcRDD[T: ClassManifest](
+class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c4ea..db15baf503 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -18,20 +18,18 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: Iterator[T] => Iterator[U],
+ f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
- override val partitioner =
- if (preservesPartitioning) firstParent[T].partitioner else None
+ override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
- f(firstParent[T].iterator(split, context))
+ f(context, split.index, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
index e8be1c4816..8d7c288593 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{Partition, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
+class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 697be8b997..d5691f2267 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rdd
-import org.apache.spark.{RangePartitioner, Logging}
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, RangePartitioner}
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -25,9 +27,9 @@ import org.apache.spark.{RangePartitioner, Logging}
* use these functions. They will work with any key type that has a `scala.math.Ordered`
* implementation.
*/
-class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest,
- V: ClassManifest,
- P <: Product2[K, V] : ClassManifest](
+class OrderedRDDFunctions[K <% Ordered[K]: ClassTag,
+ V: ClassTag,
+ P <: Product2[K, V] : ClassTag](
self: RDD[P])
extends Logging with Serializable {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 322b519bd2..4e4f860b19 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,6 +25,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.mapred._
import org.apache.hadoop.io.compress.CompressionCodec
@@ -53,7 +54,7 @@ import org.apache.spark.util.SerializableHyperLogLog
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
+class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
extends Logging
with SparkHadoopMapReduceUtil
with Serializable {
@@ -466,7 +467,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, ws) =>
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
}
@@ -482,7 +483,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, w1s, w2s) =>
(vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])
}
@@ -539,15 +540,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V, W](self, other, p)
/**
@@ -576,8 +577,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
- def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -586,16 +587,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
- path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
- def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -749,11 +750,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
*/
def values: RDD[V] = self.map(_._2)
- private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
+ private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
- private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
+ private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass
}
-private[spark] object Manifests {
- val seqSeqManifest = classManifest[Seq[Seq[_]]]
+private[spark] object ClassTags {
+ val seqSeqClassTag = classTag[Seq[Seq[_]]]
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index cd96250389..09d0a8189d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package org.apache.spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
+import scala.reflect.ClassTag
+
import org.apache.spark._
import java.io._
import scala.Serializable
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-private[spark] class ParallelCollectionPartition[T: ClassManifest](
+private[spark] class ParallelCollectionPartition[T: ClassTag](
var rddId: Long,
var slice: Int,
var values: Seq[T])
@@ -78,7 +80,7 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
}
}
-private[spark] class ParallelCollectionRDD[T: ClassManifest](
+private[spark] class ParallelCollectionRDD[T: ClassTag](
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
@@ -109,7 +111,7 @@ private object ParallelCollectionRDD {
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers.
*/
- def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
+ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
throw new IllegalArgumentException("Positive number of slices required")
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index 165cd412fc..ea8885b36e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext}
@@ -33,11 +35,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
extends NarrowDependency[T](rdd) {
@transient
- val partitions: Array[Partition] = rdd.partitions.zipWithIndex
- .filter(s => partitionFilterFunc(s._2))
+ val partitions: Array[Partition] = rdd.partitions
+ .filter(s => partitionFilterFunc(s.index)).zipWithIndex
.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
- override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+ override def getParents(partitionId: Int) = {
+ List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index)
+ }
}
@@ -47,7 +51,7 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
-class PartitionPruningRDD[T: ClassManifest](
+class PartitionPruningRDD[T: ClassTag](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
@@ -67,6 +71,6 @@ object PartitionPruningRDD {
* when its type T is not known at compile time.
*/
def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = {
- new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest)
+ new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index d5304ab0ae..1dbbe39898 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import scala.reflect.ClassTag
import org.apache.spark.{SparkEnv, Partition, TaskContext}
import org.apache.spark.broadcast.Broadcast
@@ -33,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassManifest](
+class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
envVars: Map[String, String],
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e23e7a63a1..136fa45327 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -23,6 +23,9 @@ import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.reflect.{classTag, ClassTag}
+
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
@@ -70,7 +73,7 @@ import org.apache.spark._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](
+abstract class RDD[T: ClassTag](
@transient private var sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
@@ -102,7 +105,7 @@ abstract class RDD[T: ClassManifest](
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
- val partitioner: Option[Partitioner] = None
+ @transient val partitioner: Option[Partitioner] = None
// =======================================================================
// Methods and fields available on all RDDs
@@ -115,7 +118,7 @@ abstract class RDD[T: ClassManifest](
val id: Int = sc.newRddId()
/** A friendly name for this RDD */
- var name: String = null
+ @transient var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
@@ -124,7 +127,7 @@ abstract class RDD[T: ClassManifest](
}
/** User-defined generator of this RDD*/
- var generator = Utils.getCallSiteInfo.firstUserClass
+ @transient var generator = Utils.getCallSiteInfo.firstUserClass
/** Reset generator*/
def setGenerator(_generator: String) = {
@@ -244,13 +247,13 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
+ def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
+ def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] =
new FlatMappedRDD(this, sc.clean(f))
/**
@@ -267,6 +270,19 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): RDD[T] = {
+ coalesce(numPartitions, true)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
@@ -362,25 +378,25 @@ abstract class RDD[T: ClassManifest](
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
- def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
+ def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] =
groupBy[K](f, defaultPartitioner(this))
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
groupBy(f, new HashPartitioner(numPartitions))
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
@@ -396,7 +412,6 @@ abstract class RDD[T: ClassManifest](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
-
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
@@ -428,29 +443,31 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](
+ def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[U: ClassManifest](
+ def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
- new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
*/
- def mapPartitionsWithContext[U: ClassManifest](
+ def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -458,7 +475,7 @@ abstract class RDD[T: ClassManifest](
* of the original partition.
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
- def mapPartitionsWithSplit[U: ClassManifest](
+ def mapPartitionsWithSplit[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
mapPartitionsWithIndex(f, preservesPartitioning)
}
@@ -468,14 +485,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest]
+ def mapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.map(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -483,14 +499,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest]
+ def flatMapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -498,12 +513,11 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ mapPartitionsWithIndex { (index, iter) =>
+ val a = constructA(index)
iter.map(t => {f(t, a); t})
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+ }.foreach(_ => {})
}
/**
@@ -511,12 +525,11 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.filter(t => p(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+ }, preservesPartitioning = true)
}
/**
@@ -525,7 +538,7 @@ abstract class RDD[T: ClassManifest](
* partitions* and the *same number of elements in each partition* (e.g. one was made through
* a map on the other).
*/
- def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
@@ -533,20 +546,30 @@ abstract class RDD[T: ClassManifest](
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
- def zipPartitions[B: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, V: ClassTag]
(rdd2: RDD[B])
(f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
- def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
+ (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C])
(f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
+ (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning)
- def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
(f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false)
// Actions (launch a job to return a value to the user program)
@@ -581,7 +604,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD that contains all matching values by applying `f`.
*/
- def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = {
filter(f.isDefinedAt).map(f)
}
@@ -671,7 +694,7 @@ abstract class RDD[T: ClassManifest](
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*/
- def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
@@ -720,7 +743,7 @@ abstract class RDD[T: ClassManifest](
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValue() does not support arrays")
}
// TODO: This should perhaps be distributed by default.
@@ -751,7 +774,7 @@ abstract class RDD[T: ClassManifest](
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
@@ -939,14 +962,14 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.formatSparkCallSite
+ @transient private[spark] val origin = Utils.formatSparkCallSite
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ private[spark] def elementClassTag: ClassTag[T] = classTag[T]
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
/** Returns the first parent RDD */
- protected[spark] def firstParent[U: ClassManifest] = {
+ protected[spark] def firstParent[U: ClassTag] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}
@@ -954,7 +977,7 @@ abstract class RDD[T: ClassManifest](
def context = sc
// Avoid handling doCheckpoint multiple times to prevent excessive recursion
- private var doCheckpointCalled = false
+ @transient private var doCheckpointCalled = false
/**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
@@ -1008,7 +1031,7 @@ abstract class RDD[T: ClassManifest](
origin)
def toJavaRDD() : JavaRDD[T] = {
- new JavaRDD(this)(elementClassManifest)
+ new JavaRDD(this)(elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 6009a41570..3b56e45aa9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -38,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration {
* manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
extends Logging with Serializable {
import CheckpointState._
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index 2c5253ae30..d433670cc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
import java.util.Random
import cern.jet.random.Poisson
@@ -29,9 +30,9 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition
override val index: Int = prev.index
}
-class SampledRDD[T: ClassManifest](
+class SampledRDD[T: ClassTag](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 5fe4676029..2d1bd5b481 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -14,9 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.rdd
+import scala.reflect.{ ClassTag, classTag}
+
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.io.compress.CompressionCodec
@@ -32,15 +33,15 @@ import org.apache.spark.Logging
*
* Import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
*/
-class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
+class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
self: RDD[(K, V)])
extends Logging
with Serializable {
- private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
+ private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = {
val c = {
- if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
- classManifest[T].erasure
+ if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) {
+ classTag[T].runtimeClass
} else {
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index a5d751a7bd..3682c84598 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -17,8 +17,10 @@
package org.apache.spark.rdd
-import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency,
+ SparkEnv, TaskContext}
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -32,7 +34,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam K the key class.
* @tparam V the value class.
*/
-class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
+class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
@transient var prev: RDD[P],
part: Partitioner)
extends RDD[P](prev.context, Nil) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 7af4d803e7..aab30b1bb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -18,8 +18,11 @@
package org.apache.spark.rdd
import java.util.{HashMap => JHashMap}
+
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.Partitioner
import org.apache.spark.Dependency
import org.apache.spark.TaskContext
@@ -45,7 +48,7 @@ import org.apache.spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[_ <: Product2[K, V]],
@transient var rdd2: RDD[_ <: Product2[K, W]],
part: Partitioner)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index ae8a9f36a6..08a41ac558 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -18,10 +18,13 @@
package org.apache.spark.rdd
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
-private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
extends Partition {
var split: Partition = rdd.partitions(splitIndex)
@@ -40,7 +43,7 @@ private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], spl
}
}
-class UnionRDD[T: ClassManifest](
+class UnionRDD[T: ClassTag](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 31e6fd519d..83be3c6eb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -19,10 +19,12 @@ package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
private[spark] class ZippedPartitionsPartition(
idx: Int,
- @transient rdds: Seq[RDD[_]])
+ @transient rdds: Seq[RDD[_]],
+ @transient val preferredLocations: Seq[String])
extends Partition {
override val index: Int = idx
@@ -37,33 +39,31 @@ private[spark] class ZippedPartitionsPartition(
}
}
-abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
- var rdds: Seq[RDD[_]])
+ var rdds: Seq[RDD[_]],
+ preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+ override val partitioner =
+ if (preservesPartitioning) firstParent[Any].partitioner else None
+
override def getPartitions: Array[Partition] = {
- val sizes = rdds.map(x => x.partitions.size)
- if (!sizes.forall(x => x == sizes(0))) {
+ val numParts = rdds.head.partitions.size
+ if (!rdds.forall(rdd => rdd.partitions.size == numParts)) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
- val array = new Array[Partition](sizes(0))
- for (i <- 0 until sizes(0)) {
- array(i) = new ZippedPartitionsPartition(i, rdds)
+ Array.tabulate[Partition](numParts) { i =>
+ val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
+ // Check whether there are any hosts that match all RDDs; otherwise return the union
+ val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
+ val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct
+ new ZippedPartitionsPartition(i, rdds, locs)
}
- array
}
override def getPreferredLocations(s: Partition): Seq[String] = {
- val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
- val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
- // Check whether there are any hosts that match all RDDs; otherwise return the union
- val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
- if (!exactMatchLocations.isEmpty) {
- exactMatchLocations
- } else {
- prefs.flatten.distinct
- }
+ s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
}
override def clearDependencies() {
@@ -72,12 +72,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
}
}
-class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
- var rdd2: RDD[B])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+ var rdd2: RDD[B],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -92,13 +93,14 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
}
class ZippedPartitionsRDD3
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
- var rdd3: RDD[C])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+ var rdd3: RDD[C],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -116,14 +118,15 @@ class ZippedPartitionsRDD3
}
class ZippedPartitionsRDD4
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
- var rdd4: RDD[D])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+ var rdd4: RDD[D],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
index 567b67dfee..fb5b070c18 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
@@ -18,10 +18,12 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
-private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
+private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
idx: Int,
@transient rdd1: RDD[T],
@transient rdd2: RDD[U]
@@ -42,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
}
}
-class ZippedRDD[T: ClassManifest, U: ClassManifest](
+class ZippedRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 15a04e6558..963d15b76d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,10 +19,13 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
-import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.rdd.RDD
@@ -52,19 +55,25 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- eventQueue.put(BeginEvent(task, taskInfo))
+ eventProcessActor ! BeginEvent(task, taskInfo)
+ }
+
+ // Called to report that a task has completed and results are being fetched remotely.
+ def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+ eventProcessActor ! GettingResultEvent(task, taskInfo)
}
// Called by TaskScheduler to report task completions or failures.
@@ -75,35 +84,38 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
- eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+ eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
}
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
- eventQueue.put(ExecutorLost(execId))
+ eventProcessActor ! ExecutorLost(execId)
}
// Called by TaskScheduler when a host is added
def executorGained(execId: String, host: String) {
- eventQueue.put(ExecutorGained(execId, host))
+ eventProcessActor ! ExecutorGained(execId, host)
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
- eventQueue.put(TaskSetFailed(taskSet, reason))
+ eventProcessActor ! TaskSetFailed(taskSet, reason)
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50L
+ val RESUBMIT_TIMEOUT = 50.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+ // Warns the user if a stage contains a task with size greater than this value (in KB)
+ val TASK_SIZE_TO_WARN = 100
+
+ private var eventProcessActor: ActorRef = _
private[scheduler] val nextJobId = new AtomicInteger(0)
@@ -111,9 +123,13 @@ class DAGScheduler(
private val nextStageId = new AtomicInteger(0)
- private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
- private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
+
+ private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -144,14 +160,55 @@ class DAGScheduler(
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
- // Start a thread to run the DAGScheduler event loop
+ /**
+ * Starts the event processing actor. The actor has two responsibilities:
+ *
+ * 1. Waits for events like job submission, task finished, task failure etc., and calls
+ * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them.
+ * 2. Schedules a periodical task to resubmit failed stages.
+ *
+ * NOTE: the actor cannot be started in the constructor, because the periodical task references
+ * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus
+ * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed.
+ */
def start() {
- new Thread("DAGScheduler") {
- setDaemon(true)
- override def run() {
- DAGScheduler.this.run()
+ eventProcessActor = env.actorSystem.actorOf(Props(new Actor {
+ /**
+ * A handle to the periodical task, used to cancel the task when the actor is stopped.
+ */
+ var resubmissionTask: Cancellable = _
+
+ override def preStart() {
+ import context.dispatcher
+ /**
+ * A message is sent to the actor itself periodically to remind the actor to resubmit failed
+ * stages. In this way, stage resubmission can be done within the same thread context of
+ * other event processing logic to avoid unnecessary synchronization overhead.
+ */
+ resubmissionTask = context.system.scheduler.schedule(
+ RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT, self, ResubmitFailedStages)
}
- }.start()
+
+ /**
+ * The main event loop of the DAG scheduler.
+ */
+ def receive = {
+ case event: DAGSchedulerEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+
+ /**
+ * All events are forwarded to `processEvent()`, so that the event processing logic can
+ * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite`
+ * for details.
+ */
+ if (!processEvent(event)) {
+ submitWaitingStages()
+ } else {
+ resubmissionTask.cancel()
+ context.stop(self)
+ }
+ }
+ }))
}
def addSparkListener(listener: SparkListener) {
@@ -182,34 +239,60 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
}
/**
- * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
- * as a result stage for the final RDD used directly in an action. The stage will also be
- * associated with the provided jobId.
+ * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation
+ * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided
+ * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly.
*/
private def newStage(
rdd: RDD[_],
+ numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
{
- if (shuffleDep != None) {
+ val id = nextStageId.getAndIncrement()
+ val stage =
+ new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ stageIdToStage(id) = stage
+ updateJobIdStageIdMaps(jobId, stage)
+ stageToInfos(stage) = new StageInfo(stage)
+ stage
+ }
+
+ /**
+ * Create a shuffle map Stage for the given RDD. The stage will also be associated with the
+ * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is
+ * present in the MapOutputTracker, then the number and location of available outputs are
+ * recovered from the MapOutputTracker
+ */
+ private def newOrUsedStage(
+ rdd: RDD[_],
+ numTasks: Int,
+ shuffleDep: ShuffleDependency[_,_],
+ jobId: Int,
+ callSite: Option[String] = None)
+ : Stage =
+ {
+ val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
+ if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+ val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
+ val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
+ for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
+ stage.numAvailableOutputs = locs.size
+ } else {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
}
- val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
- stageIdToStage(id) = stage
- stageToInfos(stage) = StageInfo(stage)
stage
}
@@ -265,6 +348,89 @@ class DAGScheduler(
}
/**
+ * Registers the given jobId among the jobs that need the given stage and
+ * all of that stage's ancestors.
+ */
+ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
+ def updateJobIdStageIdMapsList(stages: List[Stage]) {
+ if (!stages.isEmpty) {
+ val s = stages.head
+ stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+ jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
+ val parents = getParentStages(s.rdd, jobId)
+ val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+ updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
+ }
+ }
+ updateJobIdStageIdMapsList(List(stage))
+ }
+
+ /**
+ * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that
+ * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation.
+ */
+ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
+ val registeredStages = jobIdToStageIds(jobId)
+ val independentStages = new HashSet[Int]()
+ if (registeredStages.isEmpty) {
+ logError("No stages registered for job " + jobId)
+ } else {
+ stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
+ case (stageId, jobSet) =>
+ if (!jobSet.contains(jobId)) {
+ logError("Job %d not registered for stage %d even though that stage was registered for the job"
+ .format(jobId, stageId))
+ } else {
+ def removeStage(stageId: Int) {
+ // data structures based on Stage
+ stageIdToStage.get(stageId).foreach { s =>
+ if (running.contains(s)) {
+ logDebug("Removing running stage %d".format(stageId))
+ running -= s
+ }
+ stageToInfos -= s
+ shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
+ if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+ logDebug("Removing pending status for stage %d".format(stageId))
+ }
+ pendingTasks -= s
+ if (waiting.contains(s)) {
+ logDebug("Removing stage %d from waiting set.".format(stageId))
+ waiting -= s
+ }
+ if (failed.contains(s)) {
+ logDebug("Removing stage %d from failed set.".format(stageId))
+ failed -= s
+ }
+ }
+ // data structures based on StageId
+ stageIdToStage -= stageId
+ stageIdToJobIds -= stageId
+
+ logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+ }
+
+ jobSet -= jobId
+ if (jobSet.isEmpty) { // no other job needs this stage
+ independentStages += stageId
+ removeStage(stageId)
+ }
+ }
+ }
+ }
+ independentStages.toSet
+ }
+
+ private def jobIdToStageIdsRemove(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to remove unregistered job " + jobId)
+ } else {
+ removeJobAndIndependentStages(jobId)
+ jobIdToStageIds -= jobId
+ }
+ }
+
+ /**
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the the job finishes executing or can be used to cancel the job.
*/
@@ -277,11 +443,6 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null): JobWaiter[U] =
{
- val jobId = nextJobId.getAndIncrement()
- if (partitions.size == 0) {
- return new JobWaiter[U](this, jobId, 0, resultHandler)
- }
-
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
@@ -290,15 +451,19 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
- waiter, properties))
+ eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
}
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -329,8 +494,7 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
- listener, properties))
+ eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -339,24 +503,41 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
- eventQueue.put(JobCancelled(jobId))
+ eventProcessActor ! JobCancelled(jobId)
+ }
+
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventProcessActor ! JobGroupCancelled(groupId)
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
- eventQueue.put(AllJobsCancelled)
+ eventProcessActor ! AllJobsCancelled
}
/**
- * Process one event retrieved from the event queue.
- * Returns true if we should stop the event loop.
+ * Process one event retrieved from the event processing actor.
+ *
+ * @param event The event to be processed.
+ * @return `true` if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- val finalStage = newStage(rdd, None, jobId, Some(callSite))
+ var finalStage: Stage = null
+ try {
+ // New stage creation at times and if its not protected, the scheduler thread is killed.
+ // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted
+ finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
+ } catch {
+ case e: Exception =>
+ logWarning("Creating new stage failed due to exception - job: " + jobId, e)
+ listener.jobFailed(e)
+ return false
+ }
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -366,26 +547,31 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
+ listenerBus.post(SparkListenerJobStart(job, Array(), properties))
runLocally(job)
} else {
- listenerBus.post(SparkListenerJobStart(job, properties))
idToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
+ listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties))
submitStage(finalStage)
}
case JobCancelled(jobId) =>
- // Cancel a job: find all the running stages that are linked to this job, and cancel them.
- running.filter(_.jobId == jobId).foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ handleJobCancellation(jobId)
+
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ val jobIds = activeInGroup.map(_.jobId)
+ jobIds.foreach { handleJobCancellation }
case AllJobsCancelled =>
// Cancel all running jobs.
- running.foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ running.map(_.jobId).foreach { handleJobCancellation }
+ activeJobs.clear() // These should already be empty by this point,
+ idToActiveJob.clear() // but just in case we lost track of some jobs...
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
@@ -393,16 +579,35 @@ class DAGScheduler(
case ExecutorLost(execId) =>
handleExecutorLost(execId)
- case begin: BeginEvent =>
- listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case BeginEvent(task, taskInfo) =>
+ for (
+ job <- idToActiveJob.get(task.stageId);
+ stage <- stageIdToStage.get(task.stageId);
+ stageInfo <- stageToInfos.get(stage)
+ ) {
+ if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+ stageInfo.emittedTaskSizeWarning = true
+ logWarning(("Stage %d (%s) contains a task of very large " +
+ "size (%d KB). The maximum recommended task size is %d KB.").format(
+ task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+ }
+ }
+ listenerBus.post(SparkListenerTaskStart(task, taskInfo))
+
+ case GettingResultEvent(task, taskInfo) =>
+ listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
- case completion: CompletionEvent =>
- listenerBus.post(SparkListenerTaskEnd(
- completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
+ case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
- abortStage(stageIdToStage(taskSet.stageId), reason)
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) }
+
+ case ResubmitFailedStages =>
+ if (failed.size > 0) {
+ resubmitFailedStages()
+ }
case StopDAGScheduler =>
// Cancel any active jobs
@@ -448,42 +653,6 @@ class DAGScheduler(
}
}
-
- /**
- * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
- * events and responds by launching tasks. This runs in a dedicated thread and receives events
- * via the eventQueue.
- */
- private def run() {
- SparkEnv.set(env)
-
- while (true) {
- val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
- if (event != null) {
- logDebug("Got event of type " + event.getClass.getName)
- }
- this.synchronized { // needed in case other threads makes calls into methods of this class
- if (event != null) {
- if (processEvent(event)) {
- return
- }
- }
-
- val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
- // Periodically resubmit failed stages if some map output fetches have failed and we have
- // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
- // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
- // the same time, so we want to make sure we've identified all the reduce tasks that depend
- // on the failed node.
- if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
- resubmitFailedStages()
- } else {
- submitWaitingStages()
- }
- }
- }
- }
-
/**
* Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
@@ -500,6 +669,7 @@ class DAGScheduler(
// Broken out for easier testing in DAGSchedulerSuite.
protected def runLocallyWithinThread(job: ActiveJob) {
+ var jobResult: JobResult = JobSucceeded
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
@@ -514,31 +684,59 @@ class DAGScheduler(
}
} catch {
case e: Exception =>
+ jobResult = JobFailed(e, Some(job.finalStage))
job.listener.jobFailed(e)
+ } finally {
+ val s = job.finalStage
+ stageIdToJobIds -= s.id // clean up data structures that were populated for a local job,
+ stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
+ stageToInfos -= s // completion events or stage abort
+ jobIdToStageIds -= job.jobId
+ listenerBus.post(SparkListenerJobEnd(job, jobResult))
+ }
+ }
+
+ /** Finds the earliest-created active job that needs the stage */
+ // TODO: Probably should actually find among the active jobs that need this
+ // stage the one with the highest priority (highest-priority pool, earliest created).
+ // That should take care of at least part of the priority inversion problem with
+ // cross-job dependencies.
+ private def activeJobForStage(stage: Stage): Option[Int] = {
+ if (stageIdToJobIds.contains(stage.id)) {
+ val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
+ jobsThatUseStage.find(idToActiveJob.contains(_))
+ } else {
+ None
}
}
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
- logDebug("submitStage(" + stage + ")")
- if (!waiting(stage) && !running(stage) && !failed(stage)) {
- val missing = getMissingParentStages(stage).sortBy(_.id)
- logDebug("missing: " + missing)
- if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
+ val jobId = activeJobForStage(stage)
+ if (jobId.isDefined) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
+ submitMissingTasks(stage, jobId.get)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
}
- waiting += stage
}
+ } else {
+ abortStage(stage, "No active job for stage " + stage.id)
}
}
+
/** Called when stage's parents are available and we can now do its task. */
- private def submitMissingTasks(stage: Stage) {
+ private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -559,7 +757,7 @@ class DAGScheduler(
}
}
- val properties = if (idToActiveJob.contains(stage.jobId)) {
+ val properties = if (idToActiveJob.contains(jobId)) {
idToActiveJob(stage.jobId).properties
} else {
//this stage will be assigned to "default" pool
@@ -568,7 +766,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
- listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+ listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -589,9 +787,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- if (!stage.submissionTime.isDefined) {
- stage.submissionTime = Some(System.currentTimeMillis())
- }
+ stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -613,12 +809,12 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.submissionTime match {
+ val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
- case _ => "Unkown"
+ case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.completionTime = Some(System.currentTimeMillis)
+ stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@@ -643,6 +839,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ jobIdToStageIdsRemove(job.jobId)
listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
}
job.listener.taskSucceeded(rt.outputId, event.result)
@@ -679,7 +876,7 @@ class DAGScheduler(
changeEpoch = true)
}
clearCacheLocs()
- if (stage.outputLocs.count(_ == Nil) != 0) {
+ if (stage.outputLocs.exists(_ == Nil)) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " (" + stage.name +
@@ -696,9 +893,12 @@ class DAGScheduler(
}
waiting --= newlyRunnable
running ++= newlyRunnable
- for (stage <- newlyRunnable.sortBy(_.id)) {
+ for {
+ stage <- newlyRunnable.sortBy(_.id)
+ jobId <- activeJobForStage(stage)
+ } {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
- submitMissingTasks(stage)
+ submitMissingTasks(stage, jobId)
}
}
}
@@ -782,21 +982,42 @@ class DAGScheduler(
}
}
+ private def handleJobCancellation(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to cancel unregistered job " + jobId)
+ } else {
+ val independentStages = removeJobAndIndependentStages(jobId)
+ independentStages.foreach { taskSched.cancelTasks }
+ val error = new SparkException("Job %d cancelled".format(jobId))
+ val job = idToActiveJob(jobId)
+ job.listener.jobFailed(error)
+ jobIdToStageIds -= jobId
+ activeJobs -= job
+ idToActiveJob -= jobId
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage))))
+ }
+ }
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
+ if (!stageIdToStage.contains(failedStage.id)) {
+ // Skip all the actions if the stage has been removed.
+ return
+ }
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
- failedStage.completionTime = Some(System.currentTimeMillis())
+ stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
+ jobIdToStageIdsRemove(job.jobId)
idToActiveJob -= resultStage.jobId
activeJobs -= job
resultStageToJob -= resultStage
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -854,7 +1075,7 @@ class DAGScheduler(
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
// but this will do for now.
- rdd.dependencies.foreach(_ match {
+ rdd.dependencies.foreach {
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
@@ -862,30 +1083,29 @@ class DAGScheduler(
return locs
}
case _ =>
- })
+ }
Nil
}
private def cleanup(cleanupTime: Long) {
- var sizeBefore = stageIdToStage.size
- stageIdToStage.clearOldValues(cleanupTime)
- logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size)
-
- sizeBefore = shuffleToMapStage.size
- shuffleToMapStage.clearOldValues(cleanupTime)
- logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
-
- sizeBefore = pendingTasks.size
- pendingTasks.clearOldValues(cleanupTime)
- logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
-
- sizeBefore = stageToInfos.size
- stageToInfos.clearOldValues(cleanupTime)
- logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
+ Map(
+ "stageIdToStage" -> stageIdToStage,
+ "shuffleToMapStage" -> shuffleToMapStage,
+ "pendingTasks" -> pendingTasks,
+ "stageToInfos" -> stageToInfos,
+ "jobIdToStageIds" -> jobIdToStageIds,
+ "stageIdToJobIds" -> stageIdToJobIds).
+ foreach { case(s, t) => {
+ val sizeBefore = t.size
+ t.clearOldValues(cleanupTime)
+ logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
+ }}
}
def stop() {
- eventQueue.put(StopDAGScheduler)
+ if (eventProcessActor != null) {
+ eventProcessActor ! StopDAGScheduler
+ }
metadataCleaner.cancel()
taskSched.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index ee89bfb38d..add1187613 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -46,11 +46,16 @@ private[scheduler] case class JobSubmitted(
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
@@ -60,12 +65,13 @@ private[scheduler] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[scheduler]
-case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
+
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 370ccd183c..1791ee660d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation
@@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val conf = new JobConf(configuration)
- env.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
@@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val jobConf = new JobConf(configuration)
- env.hadoop.addCredentials(jobConf)
+ SparkHadoopUtil.get.addCredentials(jobConf)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3628b1b078..60927831a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -1,292 +1,384 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try{
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to recored
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{IOException, File, FileNotFoundException, PrintWriter}
+import java.text.SimpleDateFormat
+import java.util.{Date, Properties}
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
+ * JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
+ * after the SparkContext is created.
+ * Note that each JobLogger only works for one SparkContext
+ * @param logDirName The base directory for the log files.
+ */
+
+class JobLogger(val user: String, val logDirName: String)
+ extends SparkListener with Logging {
+
+ def this() = this(System.getProperty("user.name", "<unknown>"),
+ String.valueOf(System.currentTimeMillis()))
+
+ private val logDir =
+ if (System.getenv("SPARK_LOG_DIR") != null)
+ System.getenv("SPARK_LOG_DIR")
+ else
+ "/tmp/spark-%s".format(user)
+
+ private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
+ private val stageIDToJobID = new HashMap[Int, Int]
+ private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
+ private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
+
+ createLogDir()
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
+ /** Create a folder for log files, the folder's name is the creation time of jobLogger */
+ protected def createLogDir() {
+ val dir = new File(logDir + "/" + logDirName + "/")
+ if (dir.exists()) {
+ return
+ }
+ if (dir.mkdirs() == false) {
+ // JobLogger should throw a exception rather than continue to construct this object.
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
+ }
+ }
+
+ /**
+ * Create a log file for one job
+ * @param jobID ID of the job
+ * @exception FileNotFoundException Fail to create log file
+ */
+ protected def createLogWriter(jobID: Int) {
+ try {
+ val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
+ jobIDToPrintWriter += (jobID -> fileWriter)
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
+ }
+
+ /**
+ * Close log file, and clean the stage relationship in stageIDToJobID
+ * @param jobID ID of the job
+ */
+ protected def closeLogWriter(jobID: Int) {
+ jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
+ fileWriter.close()
+ jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
+ stageIDToJobID -= stage.id
+ })
+ jobIDToPrintWriter -= jobID
+ jobIDToStages -= jobID
+ }
+ }
+
+ /**
+ * Write info into log file
+ * @param jobID ID of the job
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
+ var writeInfo = info
+ if (withTime) {
+ val date = new Date(System.currentTimeMillis())
+ writeInfo = DATE_FORMAT.format(date) + ": " +info
+ }
+ jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
+ }
+
+ /**
+ * Write info into log file
+ * @param stageID ID of the stage
+ * @param info Info to be recorded
+ * @param withTime Controls whether to record time stamp before the info, default is true
+ */
+ protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
+ stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
+ }
+
+ /**
+ * Build stage dependency for a job
+ * @param jobID ID of the job
+ * @param stage Root stage of the job
+ */
+ protected def buildJobDep(jobID: Int, stage: Stage) {
+ if (stage.jobId == jobID) {
+ jobIDToStages.get(jobID) match {
+ case Some(stageList) => stageList += stage
+ case None => val stageList = new ListBuffer[Stage]
+ stageList += stage
+ jobIDToStages += (jobID -> stageList)
+ }
+ stageIDToJobID += (stage.id -> jobID)
+ stage.parents.foreach(buildJobDep(jobID, _))
+ }
+ }
+
+ /**
+ * Record stage dependency and RDD dependency for a stage
+ * @param jobID Job ID of the stage
+ */
+ protected def recordStageDep(jobID: Int) {
+ def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
+ var rddList = new ListBuffer[RDD[_]]
+ rddList += rdd
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
+ }
+ rddList
+ }
+ jobIDToStages.get(jobID).foreach {_.foreach { stage =>
+ var depRddDesc: String = ""
+ getRddsInStage(stage.rdd).foreach { rdd =>
+ depRddDesc += rdd.id + ","
+ }
+ var depStageDesc: String = ""
+ stage.parents.foreach { stage =>
+ depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
+ }
+ jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
+ depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
+ " STAGE_DEP=" + depStageDesc, false)
+ }
+ }
+ }
+
+ /**
+ * Generate indents and convert to String
+ * @param indent Number of indents
+ * @return string of indents
+ */
+ protected def indentString(indent: Int): String = {
+ val sb = new StringBuilder()
+ for (i <- 1 to indent) {
+ sb.append(" ")
+ }
+ sb.toString()
+ }
+
+ /**
+ * Get RDD's name
+ * @param rdd Input RDD
+ * @return String of RDD's name
+ */
+ protected def getRddName(rdd: RDD[_]): String = {
+ var rddName = rdd.getClass.getSimpleName
+ if (rdd.name != null) {
+ rddName = rdd.name
+ }
+ rddName
+ }
+
+ /**
+ * Record RDD dependency graph in a stage
+ * @param jobID Job ID of the stage
+ * @param rdd Root RDD of the stage
+ * @param indent Indent number before info
+ */
+ protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
+ val rddInfo =
+ if (rdd.getStorageLevel != StorageLevel.NONE) {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
+ rdd.origin + " " + rdd.generator
+ } else {
+ "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
+ rdd.origin + " " + rdd.generator
+ }
+ jobLogInfo(jobID, indentString(indent) + rddInfo, false)
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
+ }
+ }
+
+ /**
+ * Record stage dependency graph of a job
+ * @param jobID Job ID of the stage
+ * @param stage Root stage of the job
+ * @param indent Indent number before info, default is 0
+ */
+ protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ }
+ if (stage.jobId == jobID) {
+ jobLogInfo(jobID, indentString(indent) + stageInfo, false)
+ if (!idSet.contains(stage.id)) {
+ idSet += stage.id
+ recordRddInStageGraph(jobID, stage.rdd, indent)
+ stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
+ }
+ } else {
+ jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
+ }
+
+ /**
+ * Record task metrics into job log files, including execution info and shuffle metrics
+ * @param stageID Stage ID of the task
+ * @param status Status info of the task
+ * @param taskInfo Task description info
+ * @param taskMetrics Task running metrics
+ */
+ protected def recordTaskMetrics(stageID: Int, status: String,
+ taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
+ val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
+ " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
+ " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
+ val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
+ stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
+ }
+
+ /**
+ * When stage is submitted, record stage submit info
+ * @param stageSubmitted Stage submitted event
+ */
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
+ }
+
+ /**
+ * When stage is completed, record stage completion status
+ * @param stageCompleted Stage completed event
+ */
+ override def onStageCompleted(stageCompleted: StageCompleted) {
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) { }
+
+ /**
+ * When task ends, record task completion status and metrics
+ * @param taskEnd Task end event
+ */
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ val task = taskEnd.task
+ val taskInfo = taskEnd.taskInfo
+ var taskStatus = ""
+ task match {
+ case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
+ case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
+ }
+ taskEnd.reason match {
+ case Success => taskStatus += " STATUS=SUCCESS"
+ recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
+ case Resubmitted =>
+ taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId
+ stageLogInfo(task.stageId, taskStatus)
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
+ task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
+ mapId + " REDUCE_ID=" + reduceId
+ stageLogInfo(task.stageId, taskStatus)
+ case OtherFailure(message) =>
+ taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
+ " STAGE_ID=" + task.stageId + " INFO=" + message
+ stageLogInfo(task.stageId, taskStatus)
+ case _ =>
+ }
+ }
+
+ /**
+ * When job ends, recording job completion status and close log file
+ * @param jobEnd Job end event
+ */
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) {
+ val job = jobEnd.job
+ var info = "JOB_ID=" + job.jobId
+ jobEnd.jobResult match {
+ case JobSucceeded => info += " STATUS=SUCCESS"
+ case JobFailed(exception, _) =>
+ info += " STATUS=FAILED REASON="
+ exception.getMessage.split("\\s+").foreach(info += _ + "_")
+ case _ =>
+ }
+ jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
+ closeLogWriter(job.jobId)
+ }
+
+ /**
+ * Record job properties into job log file
+ * @param jobID ID of the job
+ * @param properties Properties of the job
+ */
+ protected def recordJobProperties(jobID: Int, properties: Properties) {
+ if(properties != null) {
+ val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
+ jobLogInfo(jobID, description, false)
+ }
+ }
+
+ /**
+ * When job starts, record job property and stage graph
+ * @param jobStart Job start event
+ */
+ override def onJobStart(jobStart: SparkListenerJobStart) {
+ val job = jobStart.job
+ val properties = jobStart.properties
+ createLogWriter(job.jobId)
+ recordJobProperties(job.jobId, properties)
+ buildJobDep(job.jobId, job.finalStage)
+ recordStageDep(job.jobId)
+ recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
+ jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 58f238d8cf..b026f860a8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -31,6 +31,7 @@ private[spark] class JobWaiter[T](
private var finishedTasks = 0
// Is the job as a whole finished (succeeded or failed)?
+ @volatile
private var _jobFinished = totalTasks == 0
def jobFinished = _jobFinished
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
index 0a786deb16..3832ee7ff6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
@@ -22,7 +22,7 @@ package org.apache.spark.scheduler
* to order tasks amongst a Schedulable's sub-queues
* "NONE" is used when the a Schedulable has no sub-queues.
*/
-object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") {
+object SchedulingMode extends Enumeration {
type SchedulingMode = Value
val FAIR,FIFO,NONE = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 802791797a..0f2deb4bcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,49 +146,56 @@ private[spark] class ShuffleMapTask(
metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
- var shuffle: ShuffleBlocks = null
- var buckets: ShuffleWriterGroup = null
+ val shuffleBlockManager = blockManager.shuffleBlockManager
+ var shuffle: ShuffleWriterGroup = null
+ var success = false
try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
- shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partitionId)
+ shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
// Write the map output to its associated buckets.
for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
- buckets.writers(bucketId).write(pair)
+ shuffle.writers(bucketId).write(pair)
}
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
- val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+ var totalTime = 0L
+ val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
+ totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
+ success = true
new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (buckets != null) {
- buckets.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
- if (shuffle != null && buckets != null) {
- shuffle.releaseWriters(buckets)
+ if (shuffle != null && shuffle.writers != null) {
+ shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 466baf9913..ee63b3c4a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,17 +24,20 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+case class SparkListenerTaskGettingResult(
+ task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
-case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null)
extends SparkListenerEvents
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
@@ -57,6 +60,12 @@ trait SparkListener {
def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
+ * Called when a task begins remotely fetching its result (will not be called for tasks that do
+ * not need to fetch the result remotely).
+ */
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+
+ /**
* Called when a task ends
*/
def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
@@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
- val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(stage.stageInfo.taskInfos.flatMap{
+ Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
@@ -122,8 +131,8 @@ object StatsReportListener extends Logging {
def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
val stats = d.statCounter
- logInfo(heading + stats)
val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(heading + stats)
logInfo(percentilesHeader)
logInfo("\t" + quantiles.mkString("\t"))
}
@@ -164,8 +173,6 @@ object StatsReportListener extends Logging {
showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
}
-
-
val seconds = 1000L
val minutes = seconds * 60
val hours = minutes * 60
@@ -189,7 +196,6 @@ object StatsReportListener extends Logging {
}
-
case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 4d3e4a17ba..85687ea330 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
@@ -89,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging {
return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..7cb3fe46e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
+ val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
-
- /** When first task was submitted to scheduler. */
- var submissionTime: Option[Long] = None
- var completionTime: Option[Long] = None
-
private var nextAttemptId = 0
def isAvailable: Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index b6f11969e5..e9f2198a00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,17 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
-case class StageInfo(
- val stage: Stage,
+class StageInfo(
+ stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
- override def toString = stage.rdd.toString
+ val stageId = stage.id
+ /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+ val rddName = stage.rdd.name
+ val name = stage.name
+ val numPartitions = stage.numPartitions
+ val numTasks = stage.numTasks
+ var emittedTaskSizeWarning = false
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1fe0d0e4e2..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- def run(attemptId: Long): T = {
+ final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 7c2a422aff..3c22edd524 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -31,9 +31,27 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
+ /**
+ * The time when the task started remotely getting the result. Will not be set if the
+ * task result was sent immediately when the task finished (as opposed to sending an
+ * IndirectTaskResult and later fetching the result from the block manager).
+ */
+ var gettingResultTime: Long = 0
+
+ /**
+ * The time when the task has completed successfully (including the time to remotely fetch
+ * results, if necessary).
+ */
var finishTime: Long = 0
+
var failed = false
+ var serializedSize: Int = 0
+
+ def markGettingResult(time: Long = System.currentTimeMillis) {
+ gettingResultTime = time
+ }
+
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +61,8 @@ class TaskInfo(
failed = true
}
+ def gettingResult: Boolean = gettingResultTime != 0
+
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +72,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
+ else if (gettingResult)
+ "GET RESULT"
else if (failed)
"FAILED"
else if (successful)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
index 47b0f387aa..35de13c385 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
@@ -18,9 +18,7 @@
package org.apache.spark.scheduler
-private[spark] object TaskLocality
- extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
-{
+private[spark] object TaskLocality extends Enumeration {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 7e468d0d67..e80cc6b0f6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -35,18 +35,15 @@ case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Se
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
-class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
extends TaskResult[T] with Externalizable {
- def this() = this(null.asInstanceOf[T], null, null)
+ def this() = this(null.asInstanceOf[ByteBuffer], null, null)
override def writeExternal(out: ObjectOutput) {
- val objectSer = SparkEnv.get.serializer.newInstance()
- val bb = objectSer.serialize(value)
-
- out.writeInt(bb.remaining())
- Utils.writeByteBuffer(bb, out)
+ out.writeInt(valueBytes.remaining);
+ Utils.writeByteBuffer(valueBytes, out)
out.writeInt(accumUpdates.size)
for ((key, value) <- accumUpdates) {
@@ -58,12 +55,10 @@ class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var me
override def readExternal(in: ObjectInput) {
- val objectSer = SparkEnv.get.serializer.newInstance()
-
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
in.readFully(byteVal)
- value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+ valueBytes = ByteBuffer.wrap(byteVal)
val numUpdates = in.readInt
if (numUpdates == 0) {
@@ -76,4 +71,9 @@ class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var me
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
}
+
+ def value(): T = {
+ val resultSer = SparkEnv.get.serializer.newInstance()
+ return resultSer.deserialize(valueBytes)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 4ea8bf8853..66ab8ea4cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -24,6 +24,7 @@ import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
+import scala.concurrent.duration._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -98,8 +99,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
this.dagScheduler = dagScheduler
}
- def initialize(context: SchedulerBackend) {
- backend = context
+ def initialize(backend: SchedulerBackend) {
+ this.backend = backend
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
@@ -119,21 +120,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.start()
if (System.getProperty("spark.speculation", "false").toBoolean) {
- new Thread("ClusterScheduler speculation check") {
- setDaemon(true)
-
- override def run() {
- logInfo("Starting speculative execution thread")
- while (true) {
- try {
- Thread.sleep(SPECULATION_INTERVAL)
- } catch {
- case e: InterruptedException => {}
- }
- checkSpeculatableTasks()
- }
- }
- }.start()
+ logInfo("Starting speculative execution thread")
+ import sc.env.actorSystem.dispatcher
+ sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
+ SPECULATION_INTERVAL milliseconds) {
+ checkSpeculatableTasks()
+ }
}
}
@@ -180,7 +172,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.killTask(tid, execId)
}
}
- tsm.error("Stage %d was cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
+ tsm.removeAllRunningTasks()
+ taskSetFinished(tsm)
}
}
@@ -256,7 +250,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
var failedExecutor: Option[String] = None
- var taskFailed = false
synchronized {
try {
if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
@@ -276,9 +269,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
taskIdToExecutorId.remove(tid)
}
- if (state == TaskState.FAILED) {
- taskFailed = true
- }
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
@@ -300,10 +290,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
- if (taskFailed) {
- // Also revive offers if a task had failed for some reason other than host lost
- backend.reviveOffers()
- }
+ }
+
+ def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
@@ -319,8 +309,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
- if (taskState == TaskState.FINISHED) {
- // The task finished successfully but the result was lost, so we should revive offers.
+ if (taskState != TaskState.KILLED) {
+ // Need to revive offers again now that the task set manager state has been updated to
+ // reflect failed tasks that need to be re-run.
backend.reviveOffers()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 29093e3b4f..bf494aa64d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler.cluster
+import java.io.NotSerializableException
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
@@ -376,6 +377,7 @@ private[spark] class ClusterTaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
+ info.serializedSize = serializedTask.limit
if (taskAttempts(index).size == 1)
taskStarted(task,info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
@@ -418,6 +420,12 @@ private[spark] class ClusterTaskSetManager(
sched.dagScheduler.taskStarted(task, info)
}
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
@@ -478,6 +486,14 @@ private[spark] class ClusterTaskSetManager(
case ef: ExceptionFailure =>
sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ if (ef.className == classOf[NotSerializableException].getName()) {
+ // If the task result wasn't serializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+ taskSet.id, index, ef.description))
+ abort("Task %s:%s had a not serializable result: %s".format(
+ taskSet.id, index, ef.description))
+ return
+ }
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
@@ -513,10 +529,10 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
if (state != TaskState.KILLED) {
numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
+ if (numFailures(index) >= MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
@@ -558,7 +574,7 @@ private[spark] class ClusterTaskSetManager(
runningTasks = runningTasksSet.size
}
- private def removeAllRunningTasks() {
+ private[cluster] def removeAllRunningTasks() {
val numRunningTasks = runningTasksSet.size
runningTasksSet.clear()
if (parent != null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index a8230ec6bc..53316dae2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -60,6 +60,10 @@ private[spark] object CoarseGrainedClusterMessages {
case object StopDriver extends CoarseGrainedClusterMessage
+ case object StopExecutor extends CoarseGrainedClusterMessage
+
+ case object StopExecutors extends CoarseGrainedClusterMessage
+
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index c0f1c6dbad..7e22c843bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -20,18 +20,17 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
-import akka.util.Duration
-import akka.util.duration._
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{SparkException, Logging, TaskState}
+import org.apache.spark.{Logging, SparkException, TaskState}
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -48,20 +47,22 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
+ private val timeout = AkkaUtils.askTimeout
+
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
private val executorActor = new HashMap[String, ActorRef]
private val executorAddress = new HashMap[String, Address]
private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
- private val actorToExecutorId = new HashMap[ActorRef, String]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
// Periodically revive offers to allow delay scheduling to work
val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong
+ import context.dispatcher
context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
}
@@ -73,12 +74,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor(sparkProperties)
- context.watch(sender)
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
- actorToExecutorId(sender) = executorId
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
@@ -87,8 +86,14 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
- freeCores(executorId) += 1
- makeOffers(executorId)
+ if (executorActor.contains(executorId)) {
+ freeCores(executorId) += 1
+ makeOffers(executorId)
+ } else {
+ // Ignoring the update since we don't know about the executor.
+ val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s"
+ logWarning(msg.format(taskId, state, sender, executorId))
+ }
}
case ReviveOffers =>
@@ -101,18 +106,20 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
sender ! true
context.stop(self)
+ case StopExecutors =>
+ logInfo("Asking each executor to shut down")
+ for (executor <- executorActor.values) {
+ executor ! StopExecutor
+ }
+ sender ! true
+
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
sender ! true
- case Terminated(actor) =>
- actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
-
- case RemoteClientDisconnected(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
+ case DisassociatedEvent(_, address, _) =>
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated"))
- case RemoteClientShutdown(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
// Make fake resource offers on all executors
@@ -140,7 +147,6 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
if (executorActor.contains(executorId)) {
logInfo("Executor " + executorId + " disconnected, so removing it")
val numCores = freeCores(executorId)
- actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
executorHost -= executorId
@@ -168,13 +174,25 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
- private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ def stopExecutors() {
+ try {
+ if (driverActor != null) {
+ logInfo("Shutting down all executors")
+ val future = driverActor.ask(StopExecutors)(timeout)
+ Await.ready(future, timeout)
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error asking standalone scheduler to shut down executors", e)
+ }
+ }
override def stop() {
+ stopExecutors()
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
}
} catch {
case e: Exception =>
@@ -197,7 +215,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
def removeExecutor(executorId: String, reason: String) {
try {
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error notifying standalone scheduler's driver actor", e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
new file mode 100644
index 0000000000..e8fecec4a6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.{Logging, SparkContext}
+
+private[spark] class SimrSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ driverFilePath: String)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ val tmpPath = new Path(driverFilePath + "_tmp")
+ val filePath = new Path(driverFilePath)
+
+ val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
+
+ override def start() {
+ super.start()
+
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+
+ logInfo("Writing to HDFS file: " + driverFilePath)
+ logInfo("Writing Akka address: " + driverUrl)
+ logInfo("Writing Spark UI Address: " + sc.ui.appUIAddress)
+
+ // Create temporary file to prevent race condition where executors get empty driverUrl file
+ val temp = fs.create(tmpPath, true)
+ temp.writeUTF(driverUrl)
+ temp.writeInt(maxCores)
+ temp.writeUTF(sc.ui.appUIAddress)
+ temp.close()
+
+ // "Atomic" rename
+ fs.rename(tmpPath, filePath)
+ }
+
+ override def stop() {
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+ fs.delete(new Path(driverFilePath), false)
+ super.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cefa970bb9..7127a72d6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -42,7 +42,7 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index 4312c46cc1..e68c527713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
+ scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
@@ -70,7 +71,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
- case ex =>
+ case ex: Throwable =>
taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
}
}
@@ -94,7 +95,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
val loader = Thread.currentThread.getContextClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex => {}
+ case ex: Throwable => {}
}
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 300fe693f1..84fe3094cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -120,7 +120,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
@@ -181,6 +181,7 @@ private[spark] class CoarseMesosSchedulerBackend(
!slaveIdsWithExecutors.contains(slaveId)) {
// Launch an executor on the slave
val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired)
+ totalCoresAcquired += cpusToUse
val taskId = newMesosTaskId()
taskIdToSlaveId(taskId) = slaveId
slaveIdsWithExecutors += slaveId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 2699f0b33e..01e95162c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -74,7 +74,7 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
}
}
-private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
+private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with ExecutorBackend
with Logging {
@@ -144,7 +144,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor ! KillTask(tid)
}
}
- tsm.error("Stage %d was cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
+ taskSetFinished(tsm)
}
}
@@ -192,17 +193,19 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
synchronized {
taskIdToTaskSetId.get(taskId) match {
case Some(taskSetId) =>
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
-
- state match {
- case TaskState.FINISHED =>
- taskSetManager.taskEnded(taskId, state, serializedData)
- case TaskState.FAILED =>
- taskSetManager.taskFailed(taskId, state, serializedData)
- case TaskState.KILLED =>
- taskSetManager.error("Task %d was killed".format(taskId))
- case _ => {}
+ val taskSetManager = activeTaskSets.get(taskSetId)
+ taskSetManager.foreach { tsm =>
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ tsm.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ tsm.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ tsm.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
}
case None =>
logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index 55f8313e87..53bf78267e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -175,7 +175,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, 4, reason.description)
+ taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 55b25f145a..e748c2275d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -27,13 +27,17 @@ import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage._
/**
- * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
+ * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*/
class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging {
- private val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+
+ private val bufferSize = {
+ System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024
+ }
def newKryoOutput() = new KryoOutput(bufferSize)
@@ -42,21 +46,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
- val blockId = TestBlockId("1")
- // Register some commonly used classes
- val toRegister: Seq[AnyRef] = Seq(
- ByteBuffer.allocate(1),
- StorageLevel.MEMORY_ONLY,
- PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock(blockId, ByteBuffer.allocate(1)),
- GetBlock(blockId),
- 1 to 10,
- 1 until 10,
- 1L to 10L,
- 1L until 10L
- )
-
- for (obj <- toRegister) kryo.register(obj.getClass)
+ // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
+ // Do this before we invoke the user registrator so the user registrator can override this.
+ kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
+
+ for (cls <- KryoSerializer.toRegister) kryo.register(cls)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
@@ -78,10 +72,6 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
new AllScalaRegistrar().apply(kryo)
kryo.setClassLoader(classLoader)
-
- // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops
- kryo.setReferences(System.getProperty("spark.kryo.referenceTracking", "true").toBoolean)
-
kryo
}
@@ -165,3 +155,21 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
trait KryoRegistrator {
def registerClasses(kryo: Kryo)
}
+
+private[serializer] object KryoSerializer {
+ // Commonly used classes.
+ private val toRegister: Seq[Class[_]] = Seq(
+ ByteBuffer.allocate(1).getClass,
+ classOf[StorageLevel],
+ classOf[PutBlock],
+ classOf[GotBlock],
+ classOf[GetBlock],
+ classOf[MapStatus],
+ classOf[BlockManagerId],
+ classOf[Array[Byte]],
+ (1 to 10).getClass,
+ (1 until 10).getClass,
+ (1L to 10L).getClass,
+ (1L until 10L).getClass
+ )
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index c7efc67a4a..7156d855d8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
- def isBroadcast = isInstanceOf[BroadcastBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
override def toString = name
override def hashCode = name.hashCode
@@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
@@ -72,6 +76,7 @@ private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@@ -84,6 +89,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
new file mode 100644
index 0000000000..c8f397609a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.ConcurrentHashMap
+
+private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
+ // To save space, 'pending' and 'failed' are encoded as special sizes:
+ @volatile var size: Long = BlockInfo.BLOCK_PENDING
+ private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
+ private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
+ private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
+ }
+
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
+ if (pending && initThread != Thread.currentThread()) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ !failed
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
+ assert (pending)
+ size = sizeInBytes
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ assert (pending)
+ size = BlockInfo.BLOCK_FAILED
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+}
+
+private object BlockInfo {
+ // initThread is logically a BlockInfo field, but we store it here because
+ // it's only needed while this block is in the 'pending' state and we want
+ // to minimize BlockInfo's memory footprint.
+ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
+
+ private val BLOCK_PENDING: Long = -1L
+ private val BLOCK_FAILED: Long = -2L
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 801f88a3db..19a025a329 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,17 +17,18 @@
package org.apache.spark.storage
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.collection.mutable.{HashMap, ArrayBuffer}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
-import akka.dispatch.{Await, Future}
-import akka.util.Duration
-import akka.util.duration._
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -45,74 +46,20 @@ private[spark] class BlockManager(
maxMemory: Long)
extends Logging {
- private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- @volatile var pending: Boolean = true
- @volatile var size: Long = -1L
- @volatile var initThread: Thread = null
- @volatile var failed = false
-
- setInitThread()
-
- private def setInitThread() {
- // Set current thread as init thread - waitForReady will not block this thread
- // (in case there is non trivial initialization which ends up calling waitForReady as part of
- // initialization itself)
- this.initThread = Thread.currentThread()
- }
-
- /**
- * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
- * Return true if the block is available, false otherwise.
- */
- def waitForReady(): Boolean = {
- if (initThread != Thread.currentThread() && pending) {
- synchronized {
- while (pending) this.wait()
- }
- }
- !failed
- }
-
- /** Mark this BlockInfo as ready (i.e. block is finished writing) */
- def markReady(sizeInBytes: Long) {
- assert (pending)
- size = sizeInBytes
- initThread = null
- failed = false
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
-
- /** Mark this BlockInfo as ready but failed */
- def markFailure() {
- assert (pending)
- size = 0
- initThread = null
- failed = true
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
- }
-
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -269,7 +216,7 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
@@ -320,89 +267,14 @@ private[spark] class BlockManager(
*/
def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
-
- // In the another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug("Level for block " + blockId + " is " + level)
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- logDebug("Block " + blockId + " not found in memory")
- }
- }
-
- // Look for block on disk, potentially loading it back into memory if required
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- if (level.useMemory && level.deserialized) {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- // Put the block back in memory before returning it
- // TODO: Consider creating a putValues that also takes in a iterator ?
- val elements = new ArrayBuffer[Any]
- elements ++= iterator
- memoryStore.putValues(blockId, elements, level, true).data match {
- case Left(iterator2) =>
- return Some(iterator2)
- case _ =>
- throw new Exception("Memory store did not return back an iterator")
- }
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else if (level.useMemory && !level.deserialized) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- // Put a copy of the block back in memory before returning it. Note that we can't
- // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
- // The use of rewind assumes this.
- assert (0 == bytes.position())
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- bytes.rewind()
- return Some(dataDeserialize(blockId, bytes))
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
- }
- }
- } else {
- logDebug("Block " + blockId + " not registered locally")
- }
- return None
+ doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
- // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
logDebug("Getting local block " + blockId + " as bytes")
-
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
@@ -413,12 +285,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
+ doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+ private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- // In the another thread is writing the block, wait for it to become ready.
+ // If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@@ -431,62 +306,104 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
- memoryStore.getBytes(blockId) match {
- case Some(bytes) =>
- return Some(bytes)
+ val result = if (asValues) {
+ memoryStore.getValues(blockId)
+ } else {
+ memoryStore.getBytes(blockId)
+ }
+ result match {
+ case Some(values) =>
+ return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
- // Look for block on disk
+ // Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- assert (0 == bytes.position())
- if (level.useMemory) {
- if (level.deserialized) {
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- // The memory store will hang onto the ByteBuffer, so give it a copy instead of
- // the memory-mapped file buffer we got from the disk store
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- }
- }
- bytes.rewind()
- return Some(bytes)
+ logDebug("Getting block " + blockId + " from disk")
+ val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
+ assert (0 == bytes.position())
+
+ if (!level.useMemory) {
+ // If the block shouldn't be stored in memory, we can just return it:
+ if (asValues) {
+ return Some(dataDeserialize(blockId, bytes))
+ } else {
+ return Some(bytes)
+ }
+ } else {
+ // Otherwise, we also have to store something in the memory store:
+ if (!level.deserialized || !asValues) {
+ // We'll store the bytes in memory if the block's storage level includes
+ // "memory serialized", or if it should be cached as objects in memory
+ // but we only requested its serialized bytes:
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ }
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ val values = dataDeserialize(blockId, bytes)
+ if (level.deserialized) {
+ // Cache the values before returning them:
+ // TODO: Consider creating a putValues that also takes in a iterator?
+ val valuesBuffer = new ArrayBuffer[Any]
+ valuesBuffer ++= values
+ memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+ case Left(values2) =>
+ return Some(values2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ } else {
+ return Some(values)
+ }
+ }
+ }
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
- return None
+ None
}
/**
* Get block from remote block managers.
*/
def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = master.getLocations(blockId)
+ doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
+ }
- // Get block from remote locations
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ logDebug("Getting remote block " + blockId + " as bytes")
+ doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+
+ private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
+ require(blockId != null, "BlockId is null")
+ val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
- return Some(dataDeserialize(blockId, data))
+ if (asValues) {
+ return Some(dataDeserialize(blockId, data))
+ } else {
+ return Some(data)
+ }
}
logDebug("The value of block " + blockId + " is null")
}
@@ -495,31 +412,6 @@ private[spark] class BlockManager(
}
/**
- * Get block from remote block managers as serialized bytes.
- */
- def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
- // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
- // refactored.
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId + " as bytes")
-
- val locations = master.getLocations(blockId)
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
- if (data != null) {
- return Some(data)
- }
- logDebug("The value of block " + blockId + " is null")
- }
- logDebug("Block " + blockId + " not found")
- return None
- }
-
- /**
* Get a block from the block manager (either local or remote).
*/
def get(blockId: BlockId): Option[Iterator[Any]] = {
@@ -566,35 +458,38 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
- writer.registerCloseEventHandler(() => {
- val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
- blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
- })
- writer
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
}
/**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ tellMaster: Boolean = true) : Long = {
+ require(values != null, "Values is null")
+ doPut(blockId, Left(values), level, tellMaster)
+ }
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
+ tellMaster: Boolean = true) {
+ require(bytes != null, "Bytes is null")
+ doPut(blockId, Right(bytes), level, tellMaster)
+ }
+
+ private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ level: StorageLevel, tellMaster: Boolean = true): Long = {
+ require(blockId != null, "BlockId is null")
+ require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
@@ -610,7 +505,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ // TODO: So the block info exists - but previous attempt to load it (?) failed.
+ // What do we do now ? Retry on it ?
oldBlockOpt.get
} else {
tinfo
@@ -619,10 +515,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
- // If we need to replicate the data, we'll want access to the values, but because our
- // put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where it
- // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ // If we're storing values and we need to replicate the data, we'll want access to the values,
+ // but because our put will read the whole iterator, there will be no values left. For the
+ // case where the put serializes data, we'll remember the bytes, above; but for the case where
+ // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@@ -631,30 +527,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (data.isRight && level.replication > 1) {
+ val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
+ data match {
+ case Left(values) => {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
}
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ case Right(bytes) => {
+ bytes.rewind()
+ // Store it only in memory at first, even if useDisk is also set to true
+ (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
+ size = bytes.limit
}
}
@@ -679,125 +596,39 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
if (level.replication > 1) {
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, level)
- logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
- }
- BlockManager.dispose(bytesAfterPut)
-
- return size
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(
- blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- // Remember the block's storage level so that we can correctly drop it to disk if it needs
- // to be dropped right after it got put into memory. Note, however, that other threads will
- // not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
-
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
- oldBlockOpt.get
- } else {
- tinfo
- }
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
- Future {
- replicate(blockId, bufferView, level)
- }
- } else {
- null
- }
-
- myInfo.synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- var marked = false
- try {
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // assert (0 == bytes.position(), "" + bytes)
-
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
- }
- } finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (! marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed")
+ data match {
+ case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case Left(values) => {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " +
+ Utils.getUsedTimeMs(remoteStartTime))
}
}
}
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- Await.ready(replicationFuture, Duration.Inf)
- }
+ BlockManager.dispose(bytesAfterPut)
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
+
+ size
}
/**
@@ -922,34 +753,20 @@ private[spark] class BlockManager(
private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping non broadcast blocks older than " + cleanupTime)
- val iterator = blockInfo.internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && !id.isBroadcast) {
- info.synchronized {
- val level = info.level
- if (level.useMemory) {
- memoryStore.remove(id)
- }
- if (level.useDisk) {
- diskStore.remove(id)
- }
- iterator.remove()
- logInfo("Dropped block " + id)
- }
- reportBlockStatus(id, info)
- }
- }
+ dropOldBlocks(cleanupTime, !_.isBroadcast)
}
private def dropOldBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, _.isBroadcast)
+ }
+
+ private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && id.isBroadcast) {
+ if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -987,13 +804,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
@@ -1063,9 +891,9 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster: BlockManagerMaster = null)
: Map[BlockId, Seq[BlockManagerId]] =
{
- // env == null and blockManagerMaster != null is used in tests
+ // blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
- val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) {
+ val blockLocations: Seq[Seq[BlockManagerId]] = if (blockManagerMaster == null) {
env.blockManager.getLocationBlockIds(blockIds)
} else {
blockManagerMaster.getLocations(blockIds)
@@ -1096,4 +924,3 @@ private[spark] object BlockManager extends Logging {
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 94038649b3..e1d68ef592 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,23 +17,25 @@
package org.apache.spark.storage
-import akka.actor.ActorRef
-import akka.dispatch.{Await, Future}
+import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import akka.actor._
import akka.pattern.ask
-import akka.util.Duration
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
+import org.apache.spark.util.AkkaUtils
-
-private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
+private[spark]
+class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection]) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
- val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val timeout = AkkaUtils.askTimeout
/** Remove a dead executor from the driver actor. This is only called on the driver side. */
def removeExecutor(execId: String) {
@@ -156,7 +158,10 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1
try {
- val future = driverActor.ask(message)(timeout)
+ val future = driverActor match {
+ case Left(a: ActorRef) => a.ask(message)(timeout)
+ case Right(b: ActorSelection) => b.ask(message)(timeout)
+ }
val result = Await.result(future, timeout)
if (result == null) {
throw new SparkException("BlockManagerMaster returned null")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 633230c0a8..21022e1cfb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -21,17 +21,15 @@ import java.util.{HashMap => JHashMap}
import scala.collection.mutable
import scala.collection.JavaConversions._
+import scala.concurrent.Future
+import scala.concurrent.duration._
import akka.actor.{Actor, ActorRef, Cancellable}
-import akka.dispatch.Future
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{AkkaUtils, Utils}
/**
* BlockManagerMasterActor is an actor on the master node to track statuses of
@@ -50,8 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Mapping from block id to the set of block managers that have the block.
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
- val akkaTimeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ private val akkaTimeout = AkkaUtils.askTimeout
initLogging()
@@ -65,6 +62,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
override def preStart() {
if (!BlockManager.getDisableHeartBeatsForTesting) {
+ import context.dispatcher
timeoutCheckingTask = context.system.scheduler.schedule(
0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
}
@@ -227,9 +225,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 951503019f..3a65e55733 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
+private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 2a67800c45..b4451fc7b8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -27,20 +34,12 @@ package org.apache.spark.storage
*/
abstract class BlockObjectWriter(val blockId: BlockId) {
- var closeEventHandler: () => Unit = _
-
def open(): BlockObjectWriter
- def close() {
- closeEventHandler()
- }
+ def close()
def isOpen: Boolean
- def registerCloseEventHandler(handler: () => Unit) {
- closeEventHandler = handler
- }
-
/**
* Flush the partial writes and commit them as a single atomic block. Return the
* number of bytes written for this commit.
@@ -59,7 +58,128 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
+ */
+ def fileSegment(): FileSegment
+
+ /**
+ * Cumulative time spent performing blocking writes, in ns.
*/
- def size(): Long
+ def timeWriting(): Long
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ override def close() = out.close()
+ override def flush() = out.flush()
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private val initialPosition = file.length()
+ private var lastValidPosition = initialPosition
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..fcd2e97982
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.Utils
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+ extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ addShutdownHook()
+
+ /**
+ * Returns the phyiscal file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+ shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ def getFile(blockId: BlockId): File = getFile(blockId.name)
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index b7ca61e938..5a1e7b4444 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,120 +17,25 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
-
- override def open(): DiskBlockObjectWriter = {
- val fos = new FileOutputStream(f, true)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- objOut.close()
- channel = null
- bs = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
- }
-
override def getSize(blockId: BlockId): Long = {
- getFile(blockId).length()
+ diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@@ -139,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.getFile(blockId)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@@ -171,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.getFile(blockId)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -193,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@@ -211,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
+ }
false
}
}
override def contains(blockId: BlockId): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: BlockId): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId.name)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
- }
- }
-
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
- if (!blockId.isShuffle) null
- else DiskStore.this.getFile(blockId).getAbsolutePath
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based off an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f39fcd87fb..e828e1d1c5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,36 +17,198 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.io.File
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConversions._
-private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+ val writers: Array[BlockObjectWriter]
-private[spark]
-trait ShuffleBlocks {
- def acquireWriters(mapId: Int): ShuffleWriterGroup
- def releaseWriters(group: ShuffleWriterGroup)
+ /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+ def releaseWriters(success: Boolean)
}
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within a ShuffleFileGroups associated with the block's reducer.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean
+
+ private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+ /**
+ * Contains all the state related to a particular shuffle. This includes a pool of unused
+ * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+ */
+ private class ShuffleState() {
+ val nextFileId = new AtomicInteger(0)
+ val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+ }
+
+ type ShuffleId = Int
+ private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
- new ShuffleBlocks {
- // Get a group of writers for a map task.
- override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
- val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
- val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ private
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
+
+ def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+ new ShuffleWriterGroup {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+ private val shuffleState = shuffleStates(shuffleId)
+ private var fileGroup: ShuffleFileGroup = null
+
+ val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+ fileGroup = getUnusedFileGroup()
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+ }
+ } else {
+ Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ val blockFile = blockManager.diskBlockManager.getFile(blockId)
+ blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+ }
+ }
+
+ override def releaseWriters(success: Boolean) {
+ if (consolidateShuffleFiles) {
+ if (success) {
+ val offsets = writers.map(_.fileSegment().offset)
+ fileGroup.recordMapOutput(mapId, offsets)
+ }
+ recycleFileGroup(fileGroup)
+ }
+ }
+
+ private def getUnusedFileGroup(): ShuffleFileGroup = {
+ val fileGroup = shuffleState.unusedFileGroups.poll()
+ if (fileGroup != null) fileGroup else newFileGroup()
+ }
+
+ private def newFileGroup(): ShuffleFileGroup = {
+ val fileId = shuffleState.nextFileId.getAndIncrement()
+ val files = Array.tabulate[File](numBuckets) { bucketId =>
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.diskBlockManager.getFile(filename)
}
- new ShuffleWriterGroup(mapId, writers)
+ val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+ shuffleState.allFileGroups.add(fileGroup)
+ fileGroup
+ }
+
+ private def recycleFileGroup(group: ShuffleFileGroup) {
+ shuffleState.unusedFileGroups.add(group)
+ }
+ }
+ }
+
+ /**
+ * Returns the physical file segment in which the given BlockId is located.
+ * This function should only be called if shuffle file consolidation is enabled, as it is
+ * an error condition if we don't find the expected block.
+ */
+ def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+ // Search all file groups associated with this shuffle.
+ val shuffleState = shuffleStates(id.shuffleId)
+ for (fileGroup <- shuffleState.allFileGroups) {
+ val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+ if (segment.isDefined) { return segment.get }
+ }
+ throw new IllegalStateException("Failed to find shuffle block: " + id)
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ shuffleStates.clearOldValues(cleanupTime)
+ }
+}
+
+private[spark]
+object ShuffleBlockManager {
+ /**
+ * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+ */
+ private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+ /**
+ * Stores the absolute index of each mapId in the files of this group. For instance,
+ * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+ */
+ private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+ /**
+ * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+ * This ordering allows us to compute block lengths by examining the following block offset.
+ * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+ * reducer.
+ */
+ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+ new PrimitiveVector[Long]()
+ }
+
+ def numBlocks = mapIdToIndex.size
+
+ def apply(bucketId: Int) = files(bucketId)
+
+ def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+ mapIdToIndex(mapId) = numBlocks
+ for (i <- 0 until offsets.length) {
+ blockOffsetsByReducer(i) += offsets(i)
}
+ }
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+ def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+ val file = files(reducerId)
+ val blockOffsets = blockOffsetsByReducer(reducerId)
+ val index = mapIdToIndex.getOrElse(mapId, -1)
+ if (index >= 0) {
+ val offset = blockOffsets(index)
+ val length =
+ if (index + 1 < numBlocks) {
+ blockOffsets(index + 1) - offset
+ } else {
+ file.length() - offset
+ }
+ assert(length >= 0)
+ Some(new FileSegment(file, offset, length))
+ } else {
+ None
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 632ff047d1..b5596dffd3 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -101,7 +101,7 @@ class StorageLevel private(
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
- result += (if (deserialized) "Deserialized " else "Serialized")
+ result += (if (deserialized) "Deserialized " else "Serialized ")
result += "%sx Replicated".format(replication)
result
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
new file mode 100644
index 0000000000..d52b3d8284
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{CountDownLatch, Executors}
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/**
+ * Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput.
+ */
+object StoragePerfTester {
+ def main(args: Array[String]) = {
+ /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
+ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
+
+ /** Number of map tasks. All tasks execute concurrently. */
+ val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
+
+ /** Number of reduce splits for each map task. */
+ val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
+
+ val recordLength = 1000 // ~1KB records
+ val totalRecords = dataSizeMb * 1000
+ val recordsPerMap = totalRecords / numMaps
+
+ val writeData = "1" * recordLength
+ val executor = Executors.newFixedThreadPool(numMaps)
+
+ System.setProperty("spark.shuffle.compress", "false")
+ System.setProperty("spark.shuffle.sync", "true")
+
+ // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
+ val sc = new SparkContext("local[4]", "Write Tester")
+ val blockManager = sc.env.blockManager
+
+ def writeOutputBytes(mapId: Int, total: AtomicLong) = {
+ val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
+ new KryoSerializer())
+ val writers = shuffle.writers
+ for (i <- 1 to recordsPerMap) {
+ writers(i % numOutputSplits).write(writeData)
+ }
+ writers.map {w =>
+ w.commit()
+ total.addAndGet(w.fileSegment().length)
+ w.close()
+ }
+
+ shuffle.releaseWriters(true)
+ }
+
+ val start = System.currentTimeMillis()
+ val latch = new CountDownLatch(numMaps)
+ val totalBytes = new AtomicLong()
+ for (task <- 1 to numMaps) {
+ executor.submit(new Runnable() {
+ override def run() = {
+ try {
+ writeOutputBytes(task, totalBytes)
+ latch.countDown()
+ } catch {
+ case e: Exception =>
+ println("Exception in child thread: " + e + " " + e.getMessage)
+ System.exit(1)
+ }
+ }
+ })
+ }
+ latch.await()
+ val end = System.currentTimeMillis()
+ val time = (end - start) / 1000.0
+ val bytesPerSecond = totalBytes.get() / time
+ val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
+
+ System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
+ System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
+ System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
+
+ executor.shutdown()
+ sc.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 860e680576..a8db37ded1 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -93,7 +93,7 @@ private[spark] object ThreadingTest {
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 42e9be6e19..a31a7e1d58 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -56,7 +56,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val diskSpaceUsed = storageStatusList.flatMap(_.blocks.values.map(_.diskSize)).fold(0L)(_+_)
val execHead = Seq("Executor ID", "Address", "RDD blocks", "Memory used", "Disk used",
- "Active tasks", "Failed tasks", "Complete tasks", "Total tasks")
+ "Active tasks", "Failed tasks", "Complete tasks", "Total tasks", "Task Time", "Shuffle Read",
+ "Shuffle Write")
def execRow(kv: Seq[String]) = {
<tr>
@@ -73,10 +74,13 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
<td>{kv(7)}</td>
<td>{kv(8)}</td>
<td>{kv(9)}</td>
+ <td>{Utils.msDurationToString(kv(10).toLong)}</td>
+ <td>{Utils.bytesToString(kv(11).toLong)}</td>
+ <td>{Utils.bytesToString(kv(12).toLong)}</td>
</tr>
}
- val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b)
+ val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
val execTable = UIUtils.listingTable(execHead, execRow, execInfo)
val content =
@@ -99,17 +103,21 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors)
}
- def getExecInfo(a: Int): Seq[String] = {
- val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId
- val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort
- val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString
- val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString
- val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString
- val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString
- val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0)
- val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0)
- val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0)
+ def getExecInfo(statusId: Int): Seq[String] = {
+ val status = sc.getExecutorStorageStatus(statusId)
+ val execId = status.blockManagerId.executorId
+ val hostPort = status.blockManagerId.hostPort
+ val rddBlocks = status.blocks.size.toString
+ val memUsed = status.memUsed().toString
+ val maxMem = status.maxMem.toString
+ val diskUsed = status.diskUsed().toString
+ val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size
+ val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
+ val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
val totalTasks = activeTasks + failedTasks + completedTasks
+ val totalDuration = listener.executorToDuration.getOrElse(execId, 0)
+ val totalShuffleRead = listener.executorToShuffleRead.getOrElse(execId, 0)
+ val totalShuffleWrite = listener.executorToShuffleWrite.getOrElse(execId, 0)
Seq(
execId,
@@ -121,7 +129,10 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
activeTasks.toString,
failedTasks.toString,
completedTasks.toString,
- totalTasks.toString
+ totalTasks.toString,
+ totalDuration.toString,
+ totalShuffleRead.toString,
+ totalShuffleWrite.toString
)
}
@@ -129,6 +140,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
val executorToTasksActive = HashMap[String, HashSet[TaskInfo]]()
val executorToTasksComplete = HashMap[String, Int]()
val executorToTasksFailed = HashMap[String, Int]()
+ val executorToDuration = HashMap[String, Long]()
+ val executorToShuffleRead = HashMap[String, Long]()
+ val executorToShuffleWrite = HashMap[String, Long]()
override def onTaskStart(taskStart: SparkListenerTaskStart) {
val eid = taskStart.taskInfo.executorId
@@ -139,6 +153,9 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
val eid = taskEnd.taskInfo.executorId
val activeTasks = executorToTasksActive.getOrElseUpdate(eid, new HashSet[TaskInfo]())
+ val newDuration = executorToDuration.getOrElse(eid, 0L) + taskEnd.taskInfo.duration
+ executorToDuration.put(eid, newDuration)
+
activeTasks -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
@@ -149,6 +166,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
+
+ // update shuffle read/write
+ if (null != taskEnd.taskMetrics) {
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach(shuffleRead =>
+ executorToShuffleRead.put(eid, executorToShuffleRead.getOrElse(eid, 0L) +
+ shuffleRead.remoteBytesRead))
+
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach(shuffleWrite =>
+ executorToShuffleWrite.put(eid, executorToShuffleWrite.getOrElse(eid, 0L) +
+ shuffleWrite.shuffleBytesWritten))
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
new file mode 100644
index 0000000000..3c53e88380
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorSummary.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+/** class for reporting aggregated metrics for each executors in stageUI */
+private[spark] class ExecutorSummary {
+ var taskTime : Long = 0
+ var failedTasks : Int = 0
+ var succeededTasks : Int = 0
+ var shuffleRead : Long = 0
+ var shuffleWrite : Long = 0
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
new file mode 100644
index 0000000000..0dd876480a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.Node
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.util.Utils
+import scala.collection.mutable
+
+/** Page showing executor summary */
+private[spark] class ExecutorTable(val parent: JobProgressUI, val stageId: Int) {
+
+ val listener = parent.listener
+ val dateFmt = parent.dateFmt
+ val isFairScheduler = listener.sc.getSchedulingMode == SchedulingMode.FAIR
+
+ def toNodeSeq(): Seq[Node] = {
+ listener.synchronized {
+ executorTable()
+ }
+ }
+
+ /** Special table which merges two header cells. */
+ private def executorTable[T](): Seq[Node] = {
+ <table class="table table-bordered table-striped table-condensed sortable">
+ <thead>
+ <th>Executor ID</th>
+ <th>Address</th>
+ <th>Task Time</th>
+ <th>Total Tasks</th>
+ <th>Failed Tasks</th>
+ <th>Succeeded Tasks</th>
+ <th>Shuffle Read</th>
+ <th>Shuffle Write</th>
+ </thead>
+ <tbody>
+ {createExecutorTable()}
+ </tbody>
+ </table>
+ }
+
+ private def createExecutorTable() : Seq[Node] = {
+ // make a executor-id -> address map
+ val executorIdToAddress = mutable.HashMap[String, String]()
+ val storageStatusList = parent.sc.getExecutorStorageStatus
+ for (statusId <- 0 until storageStatusList.size) {
+ val blockManagerId = parent.sc.getExecutorStorageStatus(statusId).blockManagerId
+ val address = blockManagerId.hostPort
+ val executorId = blockManagerId.executorId
+ executorIdToAddress.put(executorId, address)
+ }
+
+ val executorIdToSummary = listener.stageIdToExecutorSummaries.get(stageId)
+ executorIdToSummary match {
+ case Some(x) => {
+ x.toSeq.sortBy(_._1).map{
+ case (k,v) => {
+ <tr>
+ <td>{k}</td>
+ <td>{executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}</td>
+ <td>{parent.formatDuration(v.taskTime)}</td>
+ <td>{v.failedTasks + v.succeededTasks}</td>
+ <td>{v.failedTasks}</td>
+ <td>{v.succeededTasks}</td>
+ <td>{Utils.bytesToString(v.shuffleRead)}</td>
+ <td>{Utils.bytesToString(v.shuffleWrite)}</td>
+ </tr>
+ }
+ }
+ }
+ case _ => { Seq[Node]() }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index b39c0e9769..ca5a28625b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
- for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eb3b4e8522..07a42f0503 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,53 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
- val stageToPool = new HashMap[Stage, String]()
- val stageToDescription = new HashMap[Stage, String]()
- val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+ val stageIdToPool = new HashMap[Int, String]()
+ val stageIdToDescription = new HashMap[Int, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
- val activeStages = HashSet[Stage]()
- val completedStages = ListBuffer[Stage]()
- val failedStages = ListBuffer[Stage]()
+ val activeStages = HashSet[StageInfo]()
+ val completedStages = ListBuffer[StageInfo]()
+ val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
- val stageToTime = HashMap[Int, Long]()
- val stageToShuffleRead = HashMap[Int, Long]()
- val stageToShuffleWrite = HashMap[Int, Long]()
- val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
- val stageToTasksComplete = HashMap[Int, Int]()
- val stageToTasksFailed = HashMap[Int, Int]()
- val stageToTaskInfos =
+ val stageIdToTime = HashMap[Int, Long]()
+ val stageIdToShuffleRead = HashMap[Int, Long]()
+ val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageIdToTasksComplete = HashMap[Int, Int]()
+ val stageIdToTasksFailed = HashMap[Int, Int]()
+ val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
+ val stageIdToExecutorSummaries = HashMap[Int, HashMap[String, ExecutorSummary]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
- val stage = stageCompleted.stageInfo.stage
- poolToActiveStages(stageToPool(stage)) -= stage
+ val stage = stageCompleted.stage
+ poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
- def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
- stageToTaskInfos.remove(s.id)
- stageToTime.remove(s.id)
- stageToShuffleRead.remove(s.id)
- stageToShuffleWrite.remove(s.id)
- stageToTasksActive.remove(s.id)
- stageToTasksComplete.remove(s.id)
- stageToTasksFailed.remove(s.id)
- stageToPool.remove(s)
- if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ stageIdToTaskInfos.remove(s.stageId)
+ stageIdToTime.remove(s.stageId)
+ stageIdToShuffleRead.remove(s.stageId)
+ stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToTasksActive.remove(s.stageId)
+ stageIdToTasksComplete.remove(s.stageId)
+ stageIdToTasksFailed.remove(s.stageId)
+ stageIdToPool.remove(s.stageId)
+ if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@@ -95,63 +96,102 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- stageToPool(stage) = poolName
+ stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- description.map(d => stageToDescription(stage) = d)
+ description.map(d => stageIdToDescription(stage.stageId) = d)
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
-
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+ = synchronized {
+ // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+ // stageToTaskInfos already has the updated status.
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+
+ // create executor summary map if necessary
+ val executorSummaryMap = stageIdToExecutorSummaries.getOrElseUpdate(key = sid,
+ op = new HashMap[String, ExecutorSummary]())
+ executorSummaryMap.getOrElseUpdate(key = taskEnd.taskInfo.executorId,
+ op = new ExecutorSummary())
+
+ val executorSummary = executorSummaryMap.get(taskEnd.taskInfo.executorId)
+ executorSummary match {
+ case Some(y) => {
+ // first update failed-task, succeed-task
+ taskEnd.reason match {
+ case Success =>
+ y.succeededTasks += 1
+ case _ =>
+ y.failedTasks += 1
+ }
+
+ // update duration
+ y.taskTime += taskEnd.taskInfo.duration
+
+ taskEnd.taskMetrics.shuffleReadMetrics.foreach { shuffleRead =>
+ y.shuffleRead += shuffleRead.remoteBytesRead
+ }
+
+ taskEnd.taskMetrics.shuffleWriteMetrics.foreach { shuffleWrite =>
+ y.shuffleWrite += shuffleWrite.shuffleBytesWritten
+ }
+ }
+ case _ => {}
+ }
+
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
+
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
- stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
- stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
- stageToTime.getOrElseUpdate(sid, 0L)
+ stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
- stageToTime(sid) += time
+ stageIdToTime(sid) += time
totalTime += time
- stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
- stageToShuffleRead(sid) += shuffleRead
+ stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
- stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
- stageToShuffleWrite(sid) += shuffleWrite
+ stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +199,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
- activeStages -= stage
- poolToActiveStages(stageToPool(stage)) -= stage
- failedStages += stage
- trimIfNecessary(failedStages)
+ /* If two jobs share a stage we could get this failure message twice. So we first
+ * check whether we've already retired this stage. */
+ val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+ stageInfo.foreach {s =>
+ activeStages -= s
+ poolToActiveStages(stageIdToPool(stage.id)) -= s
+ failedStages += s
+ trimIfNecessary(failedStages)
+ }
case _ =>
}
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index e7eab374ad..c1ee2f3d00 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import akka.util.Duration
+import scala.concurrent.duration._
import java.text.SimpleDateFormat
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
- var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+ var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
- private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
- private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..8dcfeacb60 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
- if (!listener.stageToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,25 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
- val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
- val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
- val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
- listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+
+ val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished)
val summary =
<div>
<ul class="unstyled">
<li>
- <strong>CPU time: </strong>
- {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ <strong>Total task time across all tasks: </strong>
+ {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
@@ -83,10 +85,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
</div>
val taskHeaders: Seq[String] =
- Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
- Seq("GC Time") ++
+ Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
+ Seq("Duration", "GC Time", "Result Ser Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
- {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@@ -99,11 +101,43 @@ private[spark] class StagePage(parent: JobProgressUI) {
None
}
else {
+ val serializationTimes = validTasks.map{case (info, metrics, exception) =>
+ metrics.get.resultSerializationTime.toDouble}
+ val serializationQuantiles = "Result serialization time" +: Distribution(serializationTimes).get.getQuantiles().map(
+ ms => parent.formatDuration(ms.toLong))
+
val serviceTimes = validTasks.map{case (info, metrics, exception) =>
metrics.get.executorRunTime.toDouble}
val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
ms => parent.formatDuration(ms.toLong))
+ val gettingResultTimes = validTasks.map{case (info, metrics, exception) =>
+ if (info.gettingResultTime > 0) {
+ (info.finishTime - info.gettingResultTime).toDouble
+ } else {
+ 0.0
+ }
+ }
+ val gettingResultQuantiles = ("Time spent fetching task results" +:
+ Distribution(gettingResultTimes).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+ // The scheduler delay includes the network delay to send the task to the worker
+ // machine and to send back the result (but not the time to fetch the task result,
+ // if it needed to be fetched from the block manager on the worker).
+ val schedulerDelays = validTasks.map{case (info, metrics, exception) =>
+ val totalExecutionTime = {
+ if (info.gettingResultTime > 0) {
+ (info.gettingResultTime - info.launchTime).toDouble
+ } else {
+ (info.finishTime - info.launchTime).toDouble
+ }
+ }
+ totalExecutionTime - metrics.get.executorRunTime
+ }
+ val schedulerDelayQuantiles = ("Scheduler delay" +:
+ Distribution(schedulerDelays).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+
def getQuantileCols(data: Seq[Double]) =
Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong))
@@ -119,7 +153,11 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
- val listings: Seq[Seq[String]] = Seq(serviceQuantiles,
+ val listings: Seq[Seq[String]] = Seq(
+ serializationQuantiles,
+ serviceQuantiles,
+ gettingResultQuantiles,
+ schedulerDelayQuantiles,
if (hasShuffleRead) shuffleReadQuantiles else Nil,
if (hasShuffleWrite) shuffleWriteQuantiles else Nil)
@@ -128,12 +166,13 @@ private[spark] class StagePage(parent: JobProgressUI) {
def quantileRow(data: Seq[String]): Seq[Node] = <tr> {data.map(d => <td>{d}</td>)} </tr>
Some(listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
}
-
+ val executorTable = new ExecutorTable(parent, stageId)
val content =
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
- <h4>Tasks</h4> ++ taskTable;
+ <h4>Aggregated Metrics by Executors</h4> ++ executorTable.toNodeSeq() ++
+ <h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
}
@@ -151,8 +190,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration)
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
+ val serializationTime = metrics.map(m => m.resultSerializationTime).getOrElse(0L)
+
+ val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}
+ val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
+ val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("")
+
+ val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}
+ val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("")
+ val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("")
+
+ val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime}
+ val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("")
+ val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms =>
+ if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("")
<tr>
+ <td>{info.index}</td>
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
@@ -164,13 +218,21 @@ private[spark] class StagePage(parent: JobProgressUI) {
<td sorttable_customkey={gcTime.toString}>
{if (gcTime > 0) parent.formatDuration(gcTime) else ""}
</td>
+ <td sorttable_customkey={serializationTime.toString}>
+ {if (serializationTime > 0) parent.formatDuration(serializationTime) else ""}
+ </td>
{if (shuffleRead) {
- <td>{metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
+ <td sorttable_customkey={shuffleReadSortable}>
+ {shuffleReadReadable}
+ </td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
+ <td sorttable_customkey={writeTimeSortable}>
+ {writeTimeReadable}
+ </td>
+ <td sorttable_customkey={shuffleWriteSortable}>
+ {shuffleWriteReadable}
+ </td>
}}
<td>{exception.map(e =>
<span>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 07db8622da..463d85dfd5 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
+import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
-private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@@ -48,7 +48,7 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
{if (isFairScheduler) {<th>Pool Name</th>} else {}}
<th>Description</th>
<th>Submitted</th>
- <th>Duration</th>
+ <th>Task Time</th>
<th>Tasks: Succeeded/Total</th>
<th>Shuffle Read</th>
<th>Shuffle Write</th>
@@ -73,40 +73,43 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
- private def stageRow(s: Stage): Seq[Node] = {
+ private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
- val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ val shuffleReadSortable = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L)
+ val shuffleRead = shuffleReadSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+
+ val shuffleWriteSortable = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L)
+ val shuffleWrite = shuffleWriteSortable match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
- val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
- val totalTasks = s.numPartitions
+ val totalTasks = s.numTasks
- val poolName = listener.stageToPool.get(s)
+ val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
- <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
- val description = listener.stageToDescription.get(s)
+ <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a>
+ val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
- <td>{s.id}</td>
+ <td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}
@@ -119,8 +122,8 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
<td class="progress-cell">
{makeProgressBar(startedTasks, completedTasks, failedTasks, totalTasks)}
</td>
- <td>{shuffleRead}</td>
- <td>{shuffleWrite}</td>
+ <td sorttable_customekey={shuffleReadSortable.toString}>{shuffleRead}</td>
+ <td sorttable_customekey={shuffleWriteSortable.toString}>{shuffleWrite}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index 1d633d374a..39f422dd6b 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.storage
-import akka.util.Duration
+import scala.concurrent.duration._
import javax.servlet.http.HttpServletRequest
@@ -28,9 +28,6 @@ import org.apache.spark.ui.JettyUtils._
/** Web UI showing storage status of all RDD's in the given SparkContext. */
private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging {
- implicit val timeout = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
val indexPage = new IndexPage(this)
val rddPage = new RDDPage(this)
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index d4c5065c3f..1c8b51b8bc 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -17,11 +17,10 @@
package org.apache.spark.util
-import akka.actor.{ActorSystem, ExtendedActorSystem}
-import com.typesafe.config.ConfigFactory
-import akka.util.duration._
-import akka.remote.RemoteActorRefProvider
+import scala.concurrent.duration.{Duration, FiniteDuration}
+import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
+import com.typesafe.config.ConfigFactory
/**
* Various utility classes for working with Akka.
@@ -34,39 +33,61 @@ private[spark] object AkkaUtils {
*
* Note: the `name` parameter is important, as even if a client sends a message to right
* host + port, if the system name is incorrect, Akka will drop the message.
+ *
+ * If indestructible is set to true, the Actor System will continue running in the event
+ * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
- def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
- val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
+ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false)
+ : (ActorSystem, Int) = {
+
+ val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
- val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
+
+ val akkaTimeout = System.getProperty("spark.akka.timeout", "100").toInt
+
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
- // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
- val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
-
- val akkaConf = ConfigFactory.parseString("""
- akka.daemonic = on
- akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
- akka.stdout-loglevel = "ERROR"
- akka.actor.provider = "akka.remote.RemoteActorRefProvider"
- akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
- akka.remote.netty.hostname = "%s"
- akka.remote.netty.port = %d
- akka.remote.netty.connection-timeout = %ds
- akka.remote.netty.message-frame-size = %d MiB
- akka.remote.netty.execution-pool-size = %d
- akka.actor.default-dispatcher.throughput = %d
- akka.remote.log-remote-lifecycle-events = %s
- akka.remote.netty.write-timeout = %ds
- """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- lifecycleEvents, akkaWriteTimeout))
+ val lifecycleEvents =
+ if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+
+ val akkaHeartBeatPauses = System.getProperty("spark.akka.heartbeat.pauses", "600").toInt
+ val akkaFailureDetector =
+ System.getProperty("spark.akka.failure-detector.threshold", "300.0").toDouble
+ val akkaHeartBeatInterval = System.getProperty("spark.akka.heartbeat.interval", "1000").toInt
+
+ val akkaConf = ConfigFactory.parseString(
+ s"""
+ |akka.daemonic = on
+ |akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
+ |akka.stdout-loglevel = "ERROR"
+ |akka.jvm-exit-on-fatal-error = off
+ |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
+ |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
+ |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
+ |akka.actor.provider = "akka.remote.RemoteActorRefProvider"
+ |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport"
+ |akka.remote.netty.tcp.hostname = "$host"
+ |akka.remote.netty.tcp.port = $port
+ |akka.remote.netty.tcp.tcp-nodelay = on
+ |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
+ |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+ |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
+ |akka.actor.default-dispatcher.throughput = $akkaBatchSize
+ |akka.remote.log-remote-lifecycle-events = $lifecycleEvents
+ """.stripMargin)
- val actorSystem = ActorSystem(name, akkaConf)
+ val actorSystem = if (indestructible) {
+ IndestructibleActorSystem(name, akkaConf)
+ } else {
+ ActorSystem(name, akkaConf)
+ }
- // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
- // hack because Akka doesn't let you figure out the port through the public API yet.
val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
- val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get
- return (actorSystem, boundPort)
+ val boundPort = provider.getDefaultAddress.port.get
+ (actorSystem, boundPort)
+ }
+
+ /** Returns the default Spark timeout to use for Akka ask operations. */
+ def askTimeout: FiniteDuration = {
+ Duration.create(System.getProperty("spark.akka.askTimeout", "30").toLong, "seconds")
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index f60deafc6f..8bb4ee3bfa 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
+ private var growThreshold = LOAD_FACTOR * capacity
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
return data(2 * pos + 1).asInstanceOf[V]
} else if (curKey.eq(null)) {
return null.asInstanceOf[V]
@@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
haveNullValue = true
return
}
- val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
- if (isNewEntry) {
- incrementSize()
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = k
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ incrementSize() // Since we added a new key
+ return
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
}
}
@@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
@@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
curSize += 1
- if (curSize > LOAD_FACTOR * capacity) {
+ if (curSize > growThreshold) {
growTable()
}
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
private def rehash(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
-
- /**
- * Put an entry into a table represented by data, returning true if
- * this increases the size of the table or false otherwise. Assumes
- * that "data" has at least one empty slot.
- */
- private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
- val mask = (data.length / 2) - 1
- var pos = rehash(key.hashCode) & mask
- var i = 1
- while (true) {
- val curKey = data(2 * pos)
- if (curKey.eq(null)) {
- data(2 * pos) = key
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return true
- } else if (curKey.eq(key) || curKey == key) {
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return false
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- return false // Never reached but needed to keep compiler happy
+ it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
}
/** Double the table's size and re-hash everything */
@@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
throw new Exception("Can't make capacity bigger than 2^29 elements")
}
val newData = new Array[AnyRef](2 * newCapacity)
- var pos = 0
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- putInto(newData, data(2 * pos), data(2 * pos + 1))
+ val newMask = newCapacity - 1
+ // Insert all our old values into the new array. Note that because our old keys are
+ // unique, there's no need to check for equality here when we insert.
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (!data(2 * oldPos).eq(null)) {
+ val key = data(2 * oldPos)
+ val value = data(2 * oldPos + 1)
+ var newPos = rehash(key.hashCode) & newMask
+ var i = 1
+ var keepGoing = true
+ while (keepGoing) {
+ val curKey = newData(2 * newPos)
+ if (curKey.eq(null)) {
+ newData(2 * newPos) = key
+ newData(2 * newPos + 1) = value
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
data = newData
capacity = newCapacity
- mask = newCapacity - 1
+ mask = newMask
+ growThreshold = LOAD_FACTOR * newCapacity
}
private def nextPowerOf2(n: Int): Int = {
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
index 0b51c23f7b..a38329df03 100644
--- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -34,6 +34,8 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
override def iterator: Iterator[A] = underlying.iterator.asScala
+ override def size: Int = underlying.size
+
override def ++=(xs: TraversableOnce[A]): this.type = {
xs.foreach { this += _ }
this
diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
new file mode 100644
index 0000000000..bf71882ef7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Must be in akka.actor package as ActorSystemImpl is protected[akka].
+package akka.actor
+
+import scala.util.control.{ControlThrowable, NonFatal}
+
+import com.typesafe.config.Config
+
+/**
+ * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception.
+ * This is necessary as Spark Executors are allowed to recover from fatal exceptions
+ * (see [[org.apache.spark.executor.Executor]]).
+ */
+object IndestructibleActorSystem {
+ def apply(name: String, config: Config): ActorSystem =
+ apply(name, config, ActorSystem.findClassLoader())
+
+ def apply(name: String, config: Config, classLoader: ClassLoader): ActorSystem =
+ new IndestructibleActorSystemImpl(name, config, classLoader).start()
+}
+
+private[akka] class IndestructibleActorSystemImpl(
+ override val name: String,
+ applicationConfig: Config,
+ classLoader: ClassLoader)
+ extends ActorSystemImpl(name, applicationConfig, classLoader) {
+
+ protected override def uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = {
+ val fallbackHandler = super.uncaughtExceptionHandler
+
+ new Thread.UncaughtExceptionHandler() {
+ def uncaughtException(thread: Thread, cause: Throwable): Unit = {
+ if (isFatalError(cause) && !settings.JvmExitOnFatalError) {
+ log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " +
+ "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name)
+ //shutdown() //TODO make it configurable
+ } else {
+ fallbackHandler.uncaughtException(thread, cause)
+ }
+ }
+ }
+ }
+
+ def isFatalError(e: Throwable): Boolean = {
+ e match {
+ case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable =>
+ false
+ case _ =>
+ true
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0ce1394c77..7b41ef89f1 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -55,10 +55,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
}
-object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+object MetadataCleanerType extends Enumeration {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 277de2f8a6..dbff571de9 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -85,7 +85,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
}
override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
}
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a3b3968c5e..3f7858d2de 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,15 +18,15 @@
package org.apache.spark.util
import java.io._
-import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
+import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
-import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
-import java.util.regex.Pattern
+import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
-import scala.collection.Map
-import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
+import scala.collection.Map
+import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import scala.reflect.ClassTag
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
@@ -36,7 +36,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
-import org.apache.spark.{SparkEnv, SparkException, Logging}
+import org.apache.spark.{SparkException, Logging}
/**
@@ -148,7 +148,7 @@ private[spark] object Utils extends Logging {
return buf
}
- private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
+ private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
@@ -280,9 +280,8 @@ private[spark] object Utils extends Logging {
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val env = SparkEnv.get
val uri = new URI(url)
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -321,7 +320,7 @@ private[spark] object Utils extends Logging {
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
- def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = {
randomizeInPlace(seq.toArray)
}
@@ -819,4 +818,34 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+ /** Returns a copy of the system properties that is thread-safe to iterator over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
+
+ /**
+ * Method executed for repeating a task for side effects.
+ * Unlike a for comprehension, it permits JVM JIT optimization
+ */
+ def times(numIters: Int)(f: => Unit): Unit = {
+ var i = 0
+ while (i < numIters) {
+ f
+ i += 1
+ }
+ }
+
+ /**
+ * Timing method based on iterations that permit JVM JIT optimization.
+ * @param numIters number of iterations
+ * @param f function to be executed
+ */
+ def timeIt(numIters: Int)(f: => Unit): Long = {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
new file mode 100644
index 0000000000..e9907e6c85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.{Random => JavaRandom}
+import org.apache.spark.util.Utils.timeIt
+
+/**
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
+private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
+
+ def this() = this(System.nanoTime)
+
+ private var seed = init
+
+ // we need to just override next - this will be called by nextInt, nextDouble,
+ // nextGaussian, nextLong, etc.
+ override protected def next(bits: Int): Int = {
+ var nextSeed = seed ^ (seed << 21)
+ nextSeed ^= (nextSeed >>> 35)
+ nextSeed ^= (nextSeed << 4)
+ seed = nextSeed
+ (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
+ }
+}
+
+/** Contains benchmark method and main method to run benchmark of the RNG */
+private[spark] object XORShiftRandom {
+
+ /**
+ * Main method for running benchmark
+ * @param args takes one argument - the number of random numbers to generate
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+ println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+ System.exit(1)
+ }
+ println(benchmark(args(0).toInt))
+ }
+
+ /**
+ * @param numIters Number of random numbers to generate while running the benchmark
+ * @return Map of execution times for {@link java.util.Random java.util.Random}
+ * and XORShift
+ */
+ def benchmark(numIters: Int) = {
+
+ val seed = 1L
+ val million = 1e6.toInt
+ val javaRand = new JavaRandom(seed)
+ val xorRand = new XORShiftRandom(seed)
+
+ // this is just to warm up the JIT - we're not timing anything
+ timeIt(1e6.toInt) {
+ javaRand.nextInt()
+ xorRand.nextInt()
+ }
+
+ val iters = timeIt(numIters)(_)
+
+ /* Return results as a map instead of just printing to screen
+ in case the user wants to do something with them */
+ Map("javaTime" -> iters {javaRand.nextInt()},
+ "xorTime" -> iters {xorRand.nextInt()})
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+ private[this] val words = new Array[Long](bit2words(numBits))
+ private[this] val numWords = words.length
+
+ /**
+ * Sets the bit at the specified index to true.
+ * @param index the bit index
+ */
+ def set(index: Int) {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ words(index >> 6) |= bitmask // div by 64 and mask
+ }
+
+ /**
+ * Return the value of the bit with the specified index. The value is true if the bit with
+ * the index is currently set in this BitSet; otherwise, the result is false.
+ *
+ * @param index the bit index
+ * @return the value of the bit with the specified index
+ */
+ def get(index: Int): Boolean = {
+ val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+ (words(index >> 6) & bitmask) != 0 // div by 64 and mask
+ }
+
+ /** Return the number of bits set to true in this BitSet. */
+ def cardinality(): Int = {
+ var sum = 0
+ var i = 0
+ while (i < numWords) {
+ sum += java.lang.Long.bitCount(words(i))
+ i += 1
+ }
+ sum
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then -1 is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) {
+ * // operate on index i here
+ * }
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ def nextSetBit(fromIndex: Int): Int = {
+ var wordIndex = fromIndex >> 6
+ if (wordIndex >= numWords) {
+ return -1
+ }
+
+ // Try to find the next set bit in the current word
+ val subIndex = fromIndex & 0x3f
+ var word = words(wordIndex) >> subIndex
+ if (word != 0) {
+ return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word)
+ }
+
+ // Find the next set bit in the rest of the words
+ wordIndex += 1
+ while (wordIndex < numWords) {
+ word = words(wordIndex)
+ if (word != 0) {
+ return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word)
+ }
+ wordIndex += 1
+ }
+
+ -1
+ }
+
+ /** Return the number of longs it would take to hold numBits. */
+ private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
new file mode 100644
index 0000000000..c26f23d500
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/**
+ * A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
+ * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less
+ * space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ private var _values: Array[V] = _
+ _values = new Array[V](_keySet.capacity)
+
+ @transient private var _oldValues: Array[V] = null
+
+ // Treat the null key differently so we can use nulls in "data" to represent empty items.
+ private var haveNullValue = false
+ private var nullValue: V = null.asInstanceOf[V]
+
+ override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ if (k == null) {
+ nullValue
+ } else {
+ val pos = _keySet.getPos(k)
+ if (pos < 0) {
+ null.asInstanceOf[V]
+ } else {
+ _values(pos)
+ }
+ }
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ if (k == null) {
+ haveNullValue = true
+ nullValue = v
+ } else {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ if (k == null) {
+ if (haveNullValue) {
+ nullValue = mergeValue(nullValue)
+ } else {
+ haveNullValue = true
+ nullValue = defaultValue
+ }
+ nullValue
+ } else {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = -1
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ if (pos == -1) { // Treat position -1 as looking at the null value
+ if (haveNullValue) {
+ pos += 1
+ return (null.asInstanceOf[K], nullValue)
+ }
+ pos += 1
+ }
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
new file mode 100644
index 0000000000..87e009a4de
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
+ * removed.
+ *
+ * The underlying implementation uses Scala compiler's specialization to generate optimized
+ * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet
+ * while incurring much less memory overhead. This can serve as building blocks for higher level
+ * data structures such as an optimized HashMap.
+ *
+ * This OpenHashSet is designed to serve as building blocks for higher level data structures
+ * such as an optimized hash map. Compared with standard hash set implementations, this class
+ * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to
+ * retrieve the position of a key in the underlying array.
+ *
+ * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed
+ * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
+ */
+private[spark]
+class OpenHashSet[@specialized(Long, Int) T: ClassTag](
+ initialCapacity: Int,
+ loadFactor: Double)
+ extends Serializable {
+
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+ require(loadFactor < 1.0, "Load factor must be less than 1.0")
+ require(loadFactor > 0.0, "Load factor must be greater than 0.0")
+
+ import OpenHashSet._
+
+ def this(initialCapacity: Int) = this(initialCapacity, 0.7)
+
+ def this() = this(64)
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the non-specialized one and needs access
+ // to the "private" variables).
+
+ protected val hasher: Hasher[T] = {
+ // It would've been more natural to write the following using pattern matching. But Scala 2.9.x
+ // compiler has a bug when specialization is used together with this pattern matching, and
+ // throws:
+ // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+ // found : scala.reflect.AnyValManifest[Long]
+ // required: scala.reflect.ClassTag[Int]
+ // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+ // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+ // ...
+ val mt = classTag[T]
+ if (mt == ClassTag.Long) {
+ (new LongHasher).asInstanceOf[Hasher[T]]
+ } else if (mt == ClassTag.Int) {
+ (new IntHasher).asInstanceOf[Hasher[T]]
+ } else {
+ new Hasher[T]
+ }
+ }
+
+ protected var _capacity = nextPowerOf2(initialCapacity)
+ protected var _mask = _capacity - 1
+ protected var _size = 0
+ protected var _growThreshold = (loadFactor * _capacity).toInt
+
+ protected var _bitset = new BitSet(_capacity)
+
+ // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+ // specialization bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _data: Array[T] = _
+ _data = new Array[T](_capacity)
+
+ /** Number of elements in the set. */
+ def size: Int = _size
+
+ /** The capacity of the set (i.e. size of the underlying array). */
+ def capacity: Int = _capacity
+
+ /** Return true if this set contains the specified element. */
+ def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+ /**
+ * Add an element to the set. If the set is over capacity after the insertion, grow the set
+ * and rehash all elements.
+ */
+ def add(k: T) {
+ addWithoutResize(k)
+ rehashIfNeeded(k, grow, move)
+ }
+
+ /**
+ * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+ * The caller is responsible for calling rehashIfNeeded.
+ *
+ * Use (retval & POSITION_MASK) to get the actual position, and
+ * (retval & EXISTENCE_MASK) != 0 for prior existence.
+ *
+ * @return The position where the key is placed, plus the highest order bit is set if the key
+ * exists previously.
+ */
+ def addWithoutResize(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ // This is a new key.
+ _data(pos) = k
+ _bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (_data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
+
+ /**
+ * Rehash the set if it is overloaded.
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ if (_size > _growThreshold) {
+ rehash(k, allocateFunc, moveFunc)
+ }
+ }
+
+ /**
+ * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+ */
+ def getPos(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ return INVALID_POS
+ } else if (k == _data(pos)) {
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ INVALID_POS
+ }
+
+ /** Return the value at the specified position. */
+ def getValue(pos: Int): T = _data(pos)
+
+ /**
+ * Return the next position with an element stored, starting from the given position inclusively.
+ */
+ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+ /**
+ * Double the table's size and re-hash everything. We are not really using k, but it is declared
+ * so Scala compiler can specialize this method (which leads to calling the specialized version
+ * of putInto).
+ *
+ * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+ * this method.
+ * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+ * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+ * to a new position (in the new data array).
+ */
+ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+ val newCapacity = _capacity * 2
+ allocateFunc(newCapacity)
+ val newBitset = new BitSet(newCapacity)
+ val newData = new Array[T](newCapacity)
+ val newMask = newCapacity - 1
+
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (_bitset.get(oldPos)) {
+ val key = _data(oldPos)
+ var newPos = hashcode(hasher.hash(key)) & newMask
+ var i = 1
+ var keepGoing = true
+ // No need to check for equality here when we insert so this has one less if branch than
+ // the similar code path in addWithoutResize.
+ while (keepGoing) {
+ if (!newBitset.get(newPos)) {
+ // Inserting the key at newPos
+ newData(newPos) = key
+ newBitset.set(newPos)
+ moveFunc(oldPos, newPos)
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
+ }
+ oldPos += 1
+ }
+
+ _bitset = newBitset
+ _data = newData
+ _capacity = newCapacity
+ _mask = newMask
+ _growThreshold = (loadFactor * newCapacity).toInt
+ }
+
+ /**
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
+ */
+ private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
+
+ private def nextPowerOf2(n: Int): Int = {
+ val highBit = Integer.highestOneBit(n)
+ if (highBit == n) n else highBit << 1
+ }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+ val INVALID_POS = -1
+ val NONEXISTENCE_MASK = 0x80000000
+ val POSITION_MASK = 0xEFFFFFF
+
+ /**
+ * A set of specialized hash function implementation to avoid boxing hash code computation
+ * in the specialized implementation of OpenHashSet.
+ */
+ sealed class Hasher[@specialized(Long, Int) T] {
+ def hash(o: T): Int = o.hashCode()
+ }
+
+ class LongHasher extends Hasher[Long] {
+ override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+ }
+
+ class IntHasher extends Hasher[Int] {
+ override def hash(o: Int): Int = o
+ }
+
+ private def grow1(newSize: Int) {}
+ private def move1(oldPos: Int, newPos: Int) { }
+
+ private val grow = grow1 _
+ private val move = move1 _
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..2e1ef06cbc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect._
+
+/**
+ * A fast hash map implementation for primitive, non-null keys. This hash map supports
+ * insertions and updates, but not deletions. This map is about an order of magnitude
+ * faster than java.util.HashMap, while using much less space overhead.
+ *
+ * Under the hood, it uses our OpenHashSet implementation.
+ */
+private[spark]
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
+ initialCapacity: Int)
+ extends Iterable[(K, V)]
+ with Serializable {
+
+ def this() = this(64)
+
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
+
+ // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+ // bug that would generate two arrays (one for Object and one for specialized T).
+ protected var _keySet: OpenHashSet[K] = _
+ private var _values: Array[V] = _
+ _keySet = new OpenHashSet[K](initialCapacity)
+ _values = new Array[V](_keySet.capacity)
+
+ private var _oldValues: Array[V] = null
+
+ override def size = _keySet.size
+
+ /** Get the value for a given key */
+ def apply(k: K): V = {
+ val pos = _keySet.getPos(k)
+ _values(pos)
+ }
+
+ /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+ def getOrElse(k: K, elseValue: V): V = {
+ val pos = _keySet.getPos(k)
+ if (pos >= 0) _values(pos) else elseValue
+ }
+
+ /** Set the value for a key */
+ def update(k: K, v: V) {
+ val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+ _values(pos) = v
+ _keySet.rehashIfNeeded(k, grow, move)
+ _oldValues = null
+ }
+
+ /**
+ * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
+ * set its value to mergeValue(oldValue).
+ *
+ * @return the newly updated value.
+ */
+ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
+ val pos = _keySet.addWithoutResize(k)
+ if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+ val newValue = defaultValue
+ _values(pos & OpenHashSet.POSITION_MASK) = newValue
+ _keySet.rehashIfNeeded(k, grow, move)
+ newValue
+ } else {
+ _values(pos) = mergeValue(_values(pos))
+ _values(pos)
+ }
+ }
+
+ override def iterator = new Iterator[(K, V)] {
+ var pos = 0
+ var nextPair: (K, V) = computeNextPair()
+
+ /** Get the next value we should return from next(), or null if we're finished iterating */
+ def computeNextPair(): (K, V) = {
+ pos = _keySet.nextPos(pos)
+ if (pos >= 0) {
+ val ret = (_keySet.getValue(pos), _values(pos))
+ pos += 1
+ ret
+ } else {
+ null
+ }
+ }
+
+ def hasNext = nextPair != null
+
+ def next() = {
+ val pair = nextPair
+ nextPair = computeNextPair()
+ pair
+ }
+ }
+
+ // The following member variables are declared as protected instead of private for the
+ // specialization to work (specialized class extends the unspecialized one and needs access
+ // to the "private" variables).
+ // They also should have been val's. We use var's because there is a Scala compiler bug that
+ // would throw illegal access error at runtime if they are declared as val's.
+ protected var grow = (newCapacity: Int) => {
+ _oldValues = _values
+ _values = new Array[V](newCapacity)
+ }
+
+ protected var move = (oldPos: Int, newPos: Int) => {
+ _values(newPos) = _oldValues(oldPos)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..b84eb65c62
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/**
+ * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types.
+ */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
+ private var _numElements = 0
+ private var _array: Array[V] = _
+
+ // NB: This must be separate from the declaration, otherwise the specialized parent class
+ // will get its own array with the same initial size.
+ _array = new Array[V](initialSize)
+
+ def apply(index: Int): V = {
+ require(index < _numElements)
+ _array(index)
+ }
+
+ def +=(value: V) {
+ if (_numElements == _array.length) {
+ resize(_array.length * 2)
+ }
+ _array(_numElements) = value
+ _numElements += 1
+ }
+
+ def capacity: Int = _array.length
+
+ def length: Int = _numElements
+
+ def size: Int = _numElements
+
+ /** Gets the underlying array backing this vector. */
+ def array: Array[V] = _array
+
+ /** Trims this vector so that the capacity is equal to the size. */
+ def trim(): PrimitiveVector[V] = resize(size)
+
+ /** Resizes the array, dropping elements if the total length decreases. */
+ def resize(newLength: Int): PrimitiveVector[V] = {
+ val newArray = new Array[V](newLength)
+ _array.copyToArray(newArray)
+ _array = newArray
+ if (newLength < _numElements) {
+ _numElements = newLength
+ }
+ this
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 4434f3b87c..c443c5266e 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -27,6 +27,21 @@ import org.apache.spark.SparkContext._
class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
+ def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = {
+ t1 ++= t2
+ t1
+ }
+ def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = {
+ t1 += t2
+ t1
+ }
+ def zero(t: mutable.Set[A]) : mutable.Set[A] = {
+ new mutable.HashSet[A]()
+ }
+ }
+
test ("basic accumulation"){
sc = new SparkContext("local", "test")
val acc : Accumulator[Int] = sc.accumulator(0)
@@ -51,7 +66,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("add value to collection accumulators") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -68,22 +82,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
}
- implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
- def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
- t1 ++= t2
- t1
- }
- def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = {
- t1 += t2
- t1
- }
- def zero(t: mutable.Set[Any]) : mutable.Set[Any] = {
- new mutable.HashSet[Any]()
- }
- }
-
test ("value not readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -125,7 +124,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("localValue readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index b3a53d928b..e022accee6 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,8 +20,42 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- test("basic broadcast") {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.broadcast.factory")
+ }
+
+ test("Using HttpBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
- test("broadcast variables accessed in multiple threads") {
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d3e7..f25d921d3f 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.reflect.ClassTag
import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
@@ -62,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithContextRDD(r,
- (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
@@ -207,7 +206,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* not, but this is not done by default as usually the partitions do not refer to any RDD and
* therefore never store the lineage.
*/
- def testCheckpointing[U: ClassManifest](
+ def testCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean = true,
testRDDPartitionSize: Boolean = false
@@ -276,7 +275,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
* this RDD will remember the partitions and therefore potentially the whole lineage.
*/
- def testParentCheckpointing[U: ClassManifest](
+ def testParentCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean,
testRDDPartitionSize: Boolean
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 480bac84f3..d9cb7fead5 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -122,7 +122,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
}
assert(thrown.getClass === classOf[SparkException])
- assert(thrown.getMessage.contains("more than 4 times"))
+ assert(thrown.getMessage.contains("failed 4 times"))
}
test("caching") {
@@ -303,12 +303,13 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
}
}
+
}
object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 01a72d8401..6d1695eae7 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts {
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
forAll(masters) { (master: String) =>
- failAfter(30 seconds) {
+ failAfter(60 seconds) {
Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master),
new File(System.getenv("SPARK_HOME")))
}
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 35d1d41af1..c210dd5c3b 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -120,4 +120,20 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
+
+ test ("Dynamically adding JARS on a standalone cluster using local: URL") {
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+ val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
+ sc.addJar(sampleJarFile.replace("file", "local"))
+ val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
+ val result = sc.parallelize(testData).reduceByKey { (x,y) =>
+ val fac = Thread.currentThread.getContextClassLoader()
+ .loadClass("org.uncommons.maths.Maths")
+ .getDeclaredMethod("factorial", classOf[Int])
+ val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
+ a + b
+ }.collect()
+ assert(result.toSet === Set((1,2), (2,7), (3,121)))
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 7b0bb89ab2..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
@@ -473,6 +487,27 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void repartition() {
+ // Shrinking number of partitions
+ JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
+ JavaRDD<Integer> repartitioned1 = in1.repartition(4);
+ List<List<Integer>> result1 = repartitioned1.glom().collect();
+ Assert.assertEquals(4, result1.size());
+ for (List<Integer> l: result1) {
+ Assert.assertTrue(l.size() > 0);
+ }
+
+ // Growing number of partitions
+ JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
+ JavaRDD<Integer> repartitioned2 = in2.repartition(2);
+ List<List<Integer>> result2 = repartitioned2.glom().collect();
+ Assert.assertEquals(2, result2.size());
+ for (List<Integer> l: result2) {
+ Assert.assertTrue(l.size() > 0);
+ }
+ }
+
+ @Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
@@ -862,4 +897,37 @@ public class JavaAPISuite implements Serializable {
new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
}
+
+ @Test
+ public void collectPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+
+ List[] parts = rdd1.collectPartitions(new int[] {0});
+ Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+ parts = rdd1.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+ Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
+
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(2, 0)),
+ rdd2.collectPartitions(new int[] {0})[0]);
+
+ parts = rdd2.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+ new Tuple2<Integer, Integer>(4, 0)),
+ parts[0]);
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(6, 0),
+ new Tuple2<Integer, Integer>(7, 1)),
+ parts[1]);
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a192651491..1121e06e2e 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
import java.util.concurrent.Semaphore
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
@@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}
+ test("job group") {
+ sc = new SparkContext("local[2]", "test")
+
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ // jobA is the one to be cancelled.
+ val jobA = future {
+ sc.setJobGroup("jobA", "this is a job to be cancelled")
+ sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ }
+
+ sc.clearJobGroup()
+ val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+ // Block until both tasks of job A have started and cancel job A.
+ sem.acquire(2)
+ sc.cancelJobGroup("jobA")
+ val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+ assert(e.getMessage contains "cancel")
+
+ // Once A is cancelled, job B should finish fairly quickly.
+ assert(jobB.get() === 100)
+ }
+/*
test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
@@ -116,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
-
+ */
def testCount() {
// Cancel before launching any tasks
{
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 459e257d79..8dd5786da6 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
@transient var sc: SparkContext = _
override def beforeAll() {
- InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
super.beforeAll()
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 6013320eaa..271dc905bc 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
- val tracker = new MapOutputTracker()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+ val tracker = new MapOutputTrackerMaster()
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
- val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
- val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
- // As if we had two simulatenous fetch failures
+ // As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@@ -102,14 +100,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
- val masterTracker = new MapOutputTracker()
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+ val masterTracker = new MapOutputTrackerMaster()
+ masterTracker.trackerActor = Left(actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker"))
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
- slaveTracker.trackerActor = slaveSystem.actorFor(
- "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+ slaveTracker.trackerActor = Right(slaveSystem.actorSelection(
+ "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker"))
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 7d938917f2..1374d01774 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
.filter(_ >= 0.0)
// Run the partitions, including the consecutive empty ones, through StatCounter
- val stats: StatCounter = rdd.stats();
- assert(abs(6.0 - stats.sum) < 0.01);
- assert(abs(6.0/2 - rdd.mean) < 0.01);
- assert(abs(1.0 - rdd.variance) < 0.01);
- assert(abs(1.0 - rdd.stdev) < 0.01);
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0/2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
// Add other tests here for classes that should be able to handle empty partitions correctly
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
new file mode 100644
index 0000000000..151af0d213
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.{FunSuite, PrivateMethodTester}
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.local.LocalScheduler
+
+class SparkContextSchedulerCreationSuite
+ extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
+
+ def createTaskScheduler(master: String): TaskScheduler = {
+ // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
+ // real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
+ sc = new SparkContext("local", "test")
+ val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
+ SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+ }
+
+ test("bad-master") {
+ val e = intercept[SparkException] {
+ createTaskScheduler("localhost:1234")
+ }
+ assert(e.getMessage.contains("Could not parse Master URL"))
+ }
+
+ test("local") {
+ createTaskScheduler("local") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 1)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n") {
+ createTaskScheduler("local[5]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 5)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n-failures") {
+ createTaskScheduler("local[4, 2]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 4)
+ assert(s.maxFailures === 2)
+ case _ => fail()
+ }
+ }
+
+ test("simr") {
+ createTaskScheduler("simr://uri") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SimrSchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ test("local-cluster") {
+ createTaskScheduler("local-cluster[3, 14, 512]") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ def testYarn(master: String, expectedClassName: String) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.getClass === Class.forName(expectedClassName))
+ case _ => fail()
+ }
+ } catch {
+ case e: SparkException =>
+ assert(e.getMessage.contains("YARN mode not available"))
+ logWarning("YARN not available, could not test actual YARN scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("yarn-standalone") {
+ testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ }
+
+ test("yarn-client") {
+ testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ }
+
+ def testMesos(master: String, expectedClass: Class[_]) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.backend.getClass === expectedClass)
+ case _ => fail()
+ }
+ } catch {
+ case e: UnsatisfiedLinkError =>
+ assert(e.getMessage.contains("no mesos in"))
+ logWarning("Mesos not available, could not test actual Mesos scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("mesos fine-grained") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend])
+ }
+
+ test("mesos coarse-grained") {
+ System.setProperty("spark.mesos.coarse", "true")
+ testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend])
+ }
+
+ test("mesos with zookeeper") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend])
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
index 46a2da1724..768ca3850e 100644
--- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -37,7 +37,7 @@ class UnpersistSuite extends FunSuite with LocalSparkContext {
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 21f16ef2c6..4cb4ddc9cd 100644
--- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -15,31 +15,22 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.deploy.worker
+import java.io.File
import org.scalatest.FunSuite
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.deploy.{ExecutorState, Command, ApplicationDescription}
+class ExecutorRunnerTest extends FunSuite {
+ test("command includes appId") {
+ def f(s:String) = new File(s)
+ val sparkHome = sys.env("SPARK_HOME")
+ val appDesc = new ApplicationDescription("app name", 8, 500, Command("foo", Seq(),Map()),
+ sparkHome, "appUiUrl")
+ val appId = "12345-worker321-9876"
+ val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome),
+ f("ooga"), ExecutorState.RUNNING)
-class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
-
- test("Pruned Partitions inherit locality prefs correctly") {
- class TestPartition(i: Int) extends Partition {
- def index = i
- }
- val rdd = new RDD[Int](sc, Nil) {
- override protected def getPartitions = {
- Array[Partition](
- new TestPartition(1),
- new TestPartition(2),
- new TestPartition(3))
- }
- def compute(split: Partition, context: TaskContext) = {Iterator()}
- }
- val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
- val p = prunedRDD.partitions(0)
- assert(p.index == 2)
- assert(prunedRDD.partitions.length == 1)
+ assert(er.buildCommandSeq().last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index da032b17d9..0d4c10db8e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
+import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts
sem.acquire(2)
}
}
+
+ /**
+ * Awaiting FutureAction results
+ */
+ test("FutureAction result, infinite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration.Inf) === 100)
+ }
+
+ test("FutureAction result, finite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration(30, "seconds")) === 100)
+ }
+
+ test("FutureAction result, timeout") {
+ val f = sc.parallelize(1 to 100, 4)
+ .mapPartitions(itr => { Thread.sleep(20); itr })
+ .countAsync()
+ intercept[TimeoutException] {
+ Await.result(f, Duration(20, "milliseconds"))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000..53a7b7c44d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.{TaskContext, Partition, SharedSparkContext}
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+
+ test("Pruned Partitions inherit locality prefs correctly") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 1),
+ new TestPartition(1, 1),
+ new TestPartition(2, 1))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ Iterator()
+ }
+ }
+ val prunedRDD = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+ assert(prunedRDD.partitions.length == 1)
+ val p = prunedRDD.partitions(0)
+ assert(p.index == 0)
+ assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
+ }
+
+
+ test("Pruned Partitions can be unioned ") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 4),
+ new TestPartition(1, 5),
+ new TestPartition(2, 6))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ List(split.asInstanceOf[TestPartition].testValue).iterator
+ }
+ }
+ val prunedRDD1 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 0) true else false
+ })
+
+ val prunedRDD2 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+
+ val merged = prunedRDD1 ++ prunedRDD2
+ assert(merged.count() == 2)
+ val take = merged.take(2)
+ assert(take.apply(0) == 4)
+ assert(take.apply(1) == 6)
+ }
+}
+
+class TestPartition(i: Int, value: Int) extends Partition with Serializable {
+ def index = i
+
+ def testValue = this.value
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 413ea85322..2f81b81797 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -152,6 +152,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
+ test("repartitioned RDDs") {
+ val data = sc.parallelize(1 to 1000, 10)
+
+ // Coalesce partitions
+ val repartitioned1 = data.repartition(2)
+ assert(repartitioned1.partitions.size == 2)
+ val partitions1 = repartitioned1.glom().collect()
+ assert(partitions1(0).length > 0)
+ assert(partitions1(1).length > 0)
+ assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions
+ val repartitioned2 = data.repartition(20)
+ assert(repartitioned2.partitions.size == 20)
+ val partitions2 = repartitioned2.glom().collect()
+ assert(partitions2(0).length > 0)
+ assert(partitions2(19).length > 0)
+ assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+ }
+
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)
@@ -237,8 +257,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// test that you get over 90% locality in each group
val minLocality = coalesced2.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
- assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%")
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
+ assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%")
// test that the groups are load balanced with 100 +/- 20 elements in each
val maxImbalance = coalesced2.partitions
@@ -250,9 +270,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
val coalesced3 = data3.coalesce(numMachines*2)
val minLocality2 = coalesced3.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " +
- (minLocality2*100.).toInt + "%")
+ (minLocality2*100.0).toInt + "%")
}
test("zipped RDDs") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2a2f828be6..706d84a58b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
-import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
- var mapOutputTracker: MapOutputTracker = null
+ var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@@ -99,8 +99,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTracker()
- scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+ mapOutputTracker = new MapOutputTrackerMaster()
+ scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, sc.env) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
runLocallyWithinThread(job)
@@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(rdd, Array(0))
complete(taskSets(0), List((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("local job") {
@@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val jobId = scheduler.nextJobId.getAndIncrement()
runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial job w/ dependency") {
@@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(finalRdd, Array(0))
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cache location preferences w/ dependency") {
@@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
complete(taskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted: some failure")
+ assertDataStructuresEmpty
}
test("run trivial shuffle") {
@@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with fetch failure") {
@@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("ignore late map task completions") {
@@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with out-of-band failure and retry") {
@@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- // have hostC complete the resubmitted task
- complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
- complete(taskSets(2), Seq((Success, 42)))
- assert(results === Map(0 -> 42))
- }
-
- test("recursive shuffle failures") {
+ // have hostC complete the resubmitted task
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ complete(taskSets(2), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
+ }
+
+ test("recursive shuffle failures") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
@@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
complete(taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cached post-shuffle") {
@@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
complete(taskSets(4), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
/**
@@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345, 0)
+ private def assertDataStructuresEmpty = {
+ assert(scheduler.pendingTasks.isEmpty)
+ assert(scheduler.activeJobs.isEmpty)
+ assert(scheduler.failed.isEmpty)
+ assert(scheduler.idToActiveJob.isEmpty)
+ assert(scheduler.jobIdToStageIds.isEmpty)
+ assert(scheduler.stageIdToJobIds.isEmpty)
+ assert(scheduler.stageIdToStage.isEmpty)
+ assert(scheduler.stageToInfos.isEmpty)
+ assert(scheduler.resultStageToJob.isEmpty)
+ assert(scheduler.running.isEmpty)
+ assert(scheduler.shuffleToMapStage.isEmpty)
+ assert(scheduler.waiting.isEmpty)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index cece60dda7..002368ff55 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+ val WAIT_TIMEOUT_MILLIS = 10000
test("inner method") {
sc = new SparkContext("local", "joblogger")
@@ -58,11 +59,14 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
- val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
- val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
-
- joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
- joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
+ val shuffleMapStage =
+ new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
+ val rootStage =
+ new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
+ val rootStageInfo = new StageInfo(rootStage)
+
+ joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
+ joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getSimpleName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
@@ -88,8 +92,12 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
- joblogger.getLogDir should be ("/tmp/spark")
+ joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
@@ -115,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
-
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
joblogger.onJobStartCount should be (1)
joblogger.onJobEndCount should be (1)
joblogger.onTaskEndCount should be (8)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index a549417a47..2e41438a52 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,16 +17,62 @@
package org.apache.spark.scheduler
-import org.scalatest.FunSuite
-import org.apache.spark.{SparkContext, LocalSparkContext}
-import scala.collection.mutable
+import scala.collection.mutable.{Buffer, HashSet}
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+ with BeforeAndAfterAll {
+ /** Length of time to wait while draining listener events. */
+ val WAIT_TIMEOUT_MILLIS = 10000
+
+ override def afterAll {
+ System.clearProperty("spark.akka.frameSize")
+ }
+
+ test("basic creation of StageInfo") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ rdd2.setName("Target RDD")
+ rdd2.count
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.rddName should be {"Target RDD"}
+ first.numTasks should be {4}
+ first.numPartitions should be {4}
+ first.submissionTime should be ('defined)
+ first.completionTime should be ('defined)
+ first.taskInfos.length should be {4}
+ }
+
+ test("StageInfo with fewer tasks than partitions") {
+ sc = new SparkContext("local", "DAGSchedulerSuite")
+ val listener = new SaveStageInfo
+ sc.addSparkListener(listener)
+ val rdd1 = sc.parallelize(1 to 100, 4)
+ val rdd2 = rdd1.map(x => x.toString)
+ sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+ listener.stageInfos.size should be {1}
+ val first = listener.stageInfos.head
+ first.numTasks should be {2}
+ first.numPartitions should be {4}
+ }
test("local metrics") {
- sc = new SparkContext("local[4]", "test")
+ sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -37,9 +83,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
i
}
- val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
+ val d = sc.parallelize(0 to 1e4.toInt, 64).map{i => w(i)}
d.count()
- val WAIT_TIMEOUT_MILLIS = 10000
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
@@ -64,7 +109,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@@ -72,11 +117,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
- if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
+ if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
- if (stageInfo.stage.rdd.name == d4.name) {
+ if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@@ -89,20 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
- def checkNonZeroAvg(m: Traversable[Long], msg: String) {
- assert(m.sum / m.size.toDouble > 0.0, msg)
+ test("onTaskGettingResult() called when result fetched remotely") {
+ // Need to use local cluster mode here, because results are not ever returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ System.setProperty("spark.akka.frameSize", "1")
+ val akkaFrameSize =
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
+ val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
+ assert(result === 1.to(akkaFrameSize).toArray)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
+ assert(listener.endedTasks.contains(TASK_INDEX))
}
- def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
- val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
- !names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
+ test("onTaskGettingResult() not called when result sent directly") {
+ // Need to use local cluster mode here, because results are not ever returned through the
+ // block manager when using the LocalScheduler.
+ sc = new SparkContext("local-cluster[1,1,512]", "test")
+
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ // Make a task whose result is larger than the akka frame size
+ val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+ assert(result === 2)
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val TASK_INDEX = 0
+ assert(listener.startedTasks.contains(TASK_INDEX))
+ assert(listener.startedGettingResultTasks.isEmpty == true)
+ assert(listener.endedTasks.contains(TASK_INDEX))
+ }
+
+ def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+ assert(m.sum / m.size.toDouble > 0.0, msg)
}
class SaveStageInfo extends SparkListener {
- val stageInfos = mutable.Buffer[StageInfo]()
+ val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
- stageInfos += stage.stageInfo
+ stageInfos += stage.stage
}
}
+ class SaveTaskEvents extends SparkListener {
+ val startedTasks = new HashSet[Int]()
+ val startedGettingResultTasks = new HashSet[Int]()
+ val endedTasks = new HashSet[Int]()
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ startedTasks += taskStart.taskInfo.index
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ endedTasks += taskEnd.taskInfo.index
+ }
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
+ startedGettingResultTasks += taskGettingResult.taskInfo.index
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index b97f2b19b5..bb28a31a99 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -283,7 +283,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
// after the last failure.
- (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ (1 to manager.MAX_TASK_FAILURES).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
assert(offerResult != None,
"Expect resource offer on iteration %s to return a task".format(index))
@@ -313,6 +313,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
}
def createTaskResult(id: Int): DirectTaskResult[Int] = {
- new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
+ val valueSer = SparkEnv.get.serializer.newInstance()
+ new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
index ee150a3107..27c2d53361 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -82,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
test("handling results larger than Akka frame size") {
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
@@ -103,7 +103,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
}
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index cb76275e39..b647e8a672 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -39,7 +39,7 @@ class BlockIdSuite extends FunSuite {
fail()
} catch {
case e: IllegalStateException => // OK
- case _ => fail()
+ case _: Throwable => fail()
}
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 484a654108..5b4d63b954 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -56,7 +56,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.hostPort", "localhost:" + boundPort)
master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000..070982e798
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
+
+ val rootDir0 = Files.createTempDir()
+ rootDir0.deleteOnExit()
+ val rootDir1 = Files.createTempDir()
+ rootDir1.deleteOnExit()
+ val rootDirs = rootDir0.getName + "," + rootDir1.getName
+ println("Created root dirs: " + rootDirs)
+
+ // This suite focuses primarily on consolidation features,
+ // so we coerce consolidation if not already enabled.
+ val consolidateProp = "spark.shuffle.consolidateFiles"
+ val oldConsolidate = Option(System.getProperty(consolidateProp))
+ System.setProperty(consolidateProp, "true")
+
+ val shuffleBlockManager = new ShuffleBlockManager(null) {
+ var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+ override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+ }
+
+ var diskBlockManager: DiskBlockManager = _
+
+ override def afterAll() {
+ oldConsolidate.map(c => System.setProperty(consolidateProp, c))
+ }
+
+ override def beforeEach() {
+ diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+ shuffleBlockManager.idToSegmentMap.clear()
+ }
+
+ test("basic block creation") {
+ val blockId = new TestBlockId("test")
+ assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 10)
+ assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+ newFile.delete()
+ }
+
+ test("block appending") {
+ val blockId = new TestBlockId("test")
+ val newFile = diskBlockManager.getFile(blockId)
+ writeToFile(newFile, 15)
+ assertSegmentEquals(blockId, blockId.name, 0, 15)
+ val newFile2 = diskBlockManager.getFile(blockId)
+ assert(newFile === newFile2)
+ writeToFile(newFile2, 12)
+ assertSegmentEquals(blockId, blockId.name, 0, 27)
+ newFile.delete()
+ }
+
+ test("block remapping") {
+ val filename = "test"
+ val blockId0 = new ShuffleBlockId(1, 2, 3)
+ val newFile = diskBlockManager.getFile(filename)
+ writeToFile(newFile, 15)
+ shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+ assertSegmentEquals(blockId0, filename, 0, 15)
+
+ val blockId1 = new ShuffleBlockId(1, 2, 4)
+ val newFile2 = diskBlockManager.getFile(filename)
+ writeToFile(newFile2, 12)
+ shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+ assertSegmentEquals(blockId1, filename, 15, 12)
+
+ assert(newFile === newFile2)
+ newFile.delete()
+ }
+
+ def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+ val segment = diskBlockManager.getBlockLocation(blockId)
+ assert(segment.file.getName === filename)
+ assert(segment.offset === offset)
+ assert(segment.length === length)
+ }
+
+ def writeToFile(file: File, numBytes: Int) {
+ val writer = new FileWriter(file, true)
+ for (i <- 0 until numBytes) writer.write(i)
+ writer.close()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 8f0ec6683b..3764f4d1a0 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -34,7 +34,6 @@ class UISuite extends FunSuite {
}
val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
-
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
new file mode 100644
index 0000000000..67a57a0e7f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.scalatest.FunSuite
+import org.apache.spark.scheduler._
+import org.apache.spark.{LocalSparkContext, SparkContext, Success}
+import org.apache.spark.scheduler.SparkListenerTaskStart
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
+
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext {
+ test("test executor id to summary") {
+ val sc = new SparkContext("local", "test")
+ val listener = new JobProgressListener(sc)
+ val taskMetrics = new TaskMetrics()
+ val shuffleReadMetrics = new ShuffleReadMetrics()
+
+ // nothing in it
+ assert(listener.stageIdToExecutorSummaries.size == 0)
+
+ // finish this task, should get updated shuffleRead
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ var taskInfo = new TaskInfo(1234L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 1000)
+
+ // finish a task with unknown executor-id, nothing should happen
+ taskInfo = new TaskInfo(1234L, 0, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.size == 1)
+
+ // finish this task, should get updated duration
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1235L, 0, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-1", fail())
+ .shuffleRead == 2000)
+
+ // finish this task, should get updated duration
+ shuffleReadMetrics.remoteBytesRead = 1000
+ taskMetrics.shuffleReadMetrics = Some(shuffleReadMetrics)
+ taskInfo = new TaskInfo(1236L, 0, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL)
+ taskInfo.finishTime = 1
+ listener.onTaskEnd(new SparkListenerTaskEnd(
+ new ShuffleMapTask(0, null, null, 0, null), Success, taskInfo, taskMetrics))
+ assert(listener.stageIdToExecutorSummaries.getOrElse(0, fail()).getOrElse("exe-2", fail())
+ .shuffleRead == 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 4e40dcbdee..5aff26f9fc 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -63,54 +63,53 @@ class SizeEstimatorSuite
}
test("simple classes") {
- assert(SizeEstimator.estimate(new DummyClass1) === 16)
- assert(SizeEstimator.estimate(new DummyClass2) === 16)
- assert(SizeEstimator.estimate(new DummyClass3) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
+ expectResult(16)(SizeEstimator.estimate(new DummyClass1))
+ expectResult(16)(SizeEstimator.estimate(new DummyClass2))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass3))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass4(null)))
+ expectResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
- assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
- assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
- assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
+ expectResult(32)(SizeEstimator.estimate(new Array[Byte](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Char](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Short](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Int](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Long](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Float](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Double](10)))
+ expectResult(4016)(SizeEstimator.estimate(new Array[Int](1000)))
+ expectResult(8016)(SizeEstimator.estimate(new Array[Long](1000)))
}
test("object arrays") {
// Arrays containing nulls should just have one pointer per element
- assert(SizeEstimator.estimate(new Array[String](10)) === 56)
- assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
-
+ expectResult(56)(SizeEstimator.estimate(new Array[String](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
- assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
+ expectResult(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
+ expectResult(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
// Past size 100, our samples 100 elements, but we should still get the right size.
- assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
+ expectResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
- assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
- assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
+ expectResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
+ expectResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
// Same thing with huge array containing the same element many times. Note that this won't
// return exactly 4032 because it can't tell that *all* the elements will equal the first
@@ -128,11 +127,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
-
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
@@ -145,10 +143,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 56)
- assert(SizeEstimator.estimate(DummyString("a")) === 64)
- assert(SizeEstimator.estimate(DummyString("ab")) === 64)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
+ expectResult(56)(SizeEstimator.estimate(DummyString("")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+ def fixture = new {
+ val seed = 1L
+ val xorRand = new XORShiftRandom(seed)
+ val hundMil = 1e8.toInt
+ }
+
+ /*
+ * This test is based on a chi-squared test for randomness. The values are hard-coded
+ * so as not to create Spark's dependency on apache.commons.math3 just to call one
+ * method for calculating the exact p-value for a given number of random numbers
+ * and bins. In case one would want to move to a full-fledged test based on
+ * apache.commons.math3, the relevant class is here:
+ * org.apache.commons.math3.stat.inference.ChiSquareTest
+ */
+ test ("XORShift generates valid random numbers") {
+
+ val f = fixture
+
+ val numBins = 10
+ // create 10 bins
+ val bins = Array.fill(numBins)(0)
+
+ // populate bins based on modulus of the random number
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+ /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+ * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+ * significance level. However, should the RNG implementation change, the test should still
+ * pass at the same significance level. The chi-squared test done in R gave the following
+ * results:
+ * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699))
+ * Chi-squared test for given probabilities
+ * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+ * 10002286, 9998699)
+ * X-squared = 11.975, df = 9, p-value = 0.2147
+ * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+ * random numbers
+ * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+ * is greater than or equal to that number.
+ */
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
new file mode 100644
index 0000000000..0f1ab3d20e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+
+class BitSetSuite extends FunSuite {
+
+ test("basic set and get") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+
+ for (i <- 0 until 100) {
+ assert(!bitset.get(i))
+ }
+
+ setBits.foreach(i => bitset.set(i))
+
+ for (i <- 0 until 100) {
+ if (setBits.contains(i)) {
+ assert(bitset.get(i))
+ } else {
+ assert(!bitset.get(i))
+ }
+ }
+ assert(bitset.cardinality() === setBits.size)
+ }
+
+ test("100% full bit set") {
+ val bitset = new BitSet(10000)
+ for (i <- 0 until 10000) {
+ assert(!bitset.get(i))
+ bitset.set(i)
+ }
+ for (i <- 0 until 10000) {
+ assert(bitset.get(i))
+ }
+ assert(bitset.cardinality() === 10000)
+ }
+
+ test("nextSetBit") {
+ val setBits = Seq(0, 9, 1, 10, 90, 96)
+ val bitset = new BitSet(100)
+ setBits.foreach(i => bitset.set(i))
+
+ assert(bitset.nextSetBit(0) === 0)
+ assert(bitset.nextSetBit(1) === 1)
+ assert(bitset.nextSetBit(2) === 9)
+ assert(bitset.nextSetBit(9) === 9)
+ assert(bitset.nextSetBit(10) === 10)
+ assert(bitset.nextSetBit(11) === 90)
+ assert(bitset.nextSetBit(80) === 90)
+ assert(bitset.nextSetBit(91) === 96)
+ assert(bitset.nextSetBit(96) === 96)
+ assert(bitset.nextSetBit(97) === -1)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
new file mode 100644
index 0000000000..e9b62ea70d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive value (int)") {
+ val capacity = 1024
+ val map = new OpenHashMap[String, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset.
+ val expectedSize = capacity * (64 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("initialization") {
+ val goodMap1 = new OpenHashMap[String, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new OpenHashMap[String, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new OpenHashMap[String, String](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new OpenHashMap[String, String](0)
+ }
+ }
+
+ test("primitive value") {
+ val map = new OpenHashMap[String, Int]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i
+ assert(map(i.toString) === i)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === 0)
+
+ map(null) = -1
+ assert(map.size === 1001)
+ assert(map(null) === -1)
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1)
+ assert(set === expected.toSet)
+ }
+
+ test("non-primitive value") {
+ val map = new OpenHashMap[String, String]
+
+ for (i <- 1 to 1000) {
+ map(i.toString) = i.toString
+ assert(map(i.toString) === i.toString)
+ }
+
+ assert(map.size === 1000)
+ assert(map(null) === null)
+
+ map(null) = "-1"
+ assert(map.size === 1001)
+ assert(map(null) === "-1")
+
+ for (i <- 1 to 1000) {
+ assert(map(i.toString) === i.toString)
+ }
+
+ // Test iterator
+ val set = new HashSet[(String, String)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1")
+ assert(set === expected.toSet)
+ }
+
+ test("null keys") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ assert(map(null) === null)
+ map(null) = "hello"
+ assert(map.size === 101)
+ assert(map(null) === "hello")
+ }
+
+ test("null values") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = null
+ }
+ assert(map.size === 100)
+ assert(map("1") === null)
+ assert(map(null) === null)
+ assert(map.size === 100)
+ map(null) = null
+ assert(map.size === 101)
+ assert(map(null) === null)
+ }
+
+ test("changeValue") {
+ val map = new OpenHashMap[String, String]()
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toString, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ assert(map(null) === null)
+ map.changeValue(null, { "null!" }, v => { assert(false); v })
+ assert(map.size === 401)
+ map.changeValue(null, { assert(false); "" }, v => {
+ assert(v === "null!")
+ "null!!"
+ })
+ assert(map.size === 401)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new OpenHashMap[String, String](1)
+ for (i <- 1 to 100) {
+ map(i.toString) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toString) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
new file mode 100644
index 0000000000..1b24f8f287
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.util.SizeEstimator
+
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive int") {
+ val loadFactor = 0.7
+ val set = new OpenHashSet[Int](64, loadFactor)
+ for (i <- 0 until 1024) {
+ set.add(i)
+ }
+ assert(set.size === 1024)
+ assert(set.capacity > 1024)
+ val actualSize = SizeEstimator.estimate(set)
+ // 32 bits for the ints + 1 bit for the bitset
+ val expectedSize = set.capacity * (32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("primitive int") {
+ val set = new OpenHashSet[Int]
+ assert(set.size === 0)
+ assert(!set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(10)
+ assert(set.contains(10))
+ assert(!set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 2)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(!set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(999)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+
+ set.add(50)
+ assert(set.size === 3)
+ assert(set.contains(10))
+ assert(set.contains(50))
+ assert(set.contains(999))
+ assert(!set.contains(10000))
+ }
+
+ test("primitive long") {
+ val set = new OpenHashSet[Long]
+ assert(set.size === 0)
+ assert(!set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(10L)
+ assert(set.size === 1)
+ assert(set.contains(10L))
+ assert(!set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 2)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(!set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(999L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+
+ set.add(50L)
+ assert(set.size === 3)
+ assert(set.contains(10L))
+ assert(set.contains(50L))
+ assert(set.contains(999L))
+ assert(!set.contains(10000L))
+ }
+
+ test("non-primitive") {
+ val set = new OpenHashSet[String]
+ assert(set.size === 0)
+ assert(!set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(10.toString)
+ assert(set.size === 1)
+ assert(set.contains(10.toString))
+ assert(!set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 2)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(!set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(999.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+
+ set.add(50.toString)
+ assert(set.size === 3)
+ assert(set.contains(10.toString))
+ assert(set.contains(50.toString))
+ assert(set.contains(999.toString))
+ assert(!set.contains(10000.toString))
+ }
+
+ test("non-primitive set growth") {
+ val set = new OpenHashSet[String]
+ for (i <- 1 to 1000) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toString)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+
+ test("primitive set growth") {
+ val set = new OpenHashSet[Long]
+ for (i <- 1 to 1000) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ for (i <- 1 to 100) {
+ set.add(i.toLong)
+ }
+ assert(set.size === 1000)
+ assert(set.capacity > 1000)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
new file mode 100644
index 0000000000..3b60decee9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.HashSet
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive key, value (int, int)") {
+ val capacity = 1024
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 32 bit for keys, 32 bit for values, and 1 bit for the bitset.
+ val expectedSize = capacity * (32 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
+
+ test("initialization") {
+ val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
+ assert(goodMap1.size === 0)
+ val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255)
+ assert(goodMap2.size === 0)
+ val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256)
+ assert(goodMap3.size === 0)
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](-1)
+ }
+ intercept[IllegalArgumentException] {
+ new PrimitiveKeyOpenHashMap[Int, Int](0)
+ }
+ }
+
+ test("basic operations") {
+ val longBase = 1000000L
+ val map = new PrimitiveKeyOpenHashMap[Long, Int]
+
+ for (i <- 1 to 1000) {
+ map(i + longBase) = i
+ assert(map(i + longBase) === i)
+ }
+
+ assert(map.size === 1000)
+
+ for (i <- 1 to 1000) {
+ assert(map(i + longBase) === i)
+ }
+
+ // Test iterator
+ val set = new HashSet[(Long, Int)]
+ for ((k, v) <- map) {
+ set.add((k, v))
+ }
+ assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet)
+ }
+
+ test("null values") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = null
+ }
+ assert(map.size === 100)
+ assert(map(1.toLong) === null)
+ }
+
+ test("changeValue") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String]()
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ val res = map.changeValue(i.toLong, { assert(false); "" }, v => {
+ assert(v === i.toString)
+ v + "!"
+ })
+ assert(res === i + "!")
+ }
+ // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
+ // bug where changeValue would return the wrong result when the map grew on that insert
+ for (i <- 101 to 400) {
+ val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v })
+ assert(res === i + "!")
+ }
+ assert(map.size === 400)
+ }
+
+ test("inserting in capacity-1 map") {
+ val map = new PrimitiveKeyOpenHashMap[Long, String](1)
+ for (i <- 1 to 100) {
+ map(i.toLong) = i.toString
+ }
+ assert(map.size === 100)
+ for (i <- 1 to 100) {
+ assert(map(i.toLong) === i.toString)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
new file mode 100644
index 0000000000..970dade628
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveVectorSuite extends FunSuite {
+
+ test("primitive value") {
+ val vector = new PrimitiveVector[Int]
+
+ for (i <- 0 until 1000) {
+ vector += i
+ assert(vector(i) === i)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i)
+ }
+ }
+
+ test("non-primitive value") {
+ val vector = new PrimitiveVector[String]
+
+ for (i <- 0 until 1000) {
+ vector += i.toString
+ assert(vector(i) === i.toString)
+ }
+
+ assert(vector.size === 1000)
+ assert(vector.size == vector.length)
+ intercept[IllegalArgumentException] {
+ vector(1000)
+ }
+
+ for (i <- 0 until 1000) {
+ assert(vector(i) == i.toString)
+ }
+ }
+
+ test("ideal growth") {
+ val vector = new PrimitiveVector[Long](initialSize = 1)
+ vector += 1
+ for (i <- 1 until 1024) {
+ vector += i
+ assert(vector.size === i + 1)
+ assert(vector.capacity === Integer.highestOneBit(i) * 2)
+ }
+ assert(vector.capacity === 1024)
+ vector += 1024
+ assert(vector.capacity === 2048)
+ }
+
+ test("ideal size") {
+ val vector = new PrimitiveVector[Long](8192)
+ for (i <- 0 until 8192) {
+ vector += i
+ }
+ assert(vector.size === 8192)
+ assert(vector.capacity === 8192)
+ val actualSize = SizeEstimator.estimate(vector)
+ val expectedSize = 8192 * 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array.
+ assert(actualSize < expectedSize * 1.1)
+ }
+
+ test("resizing") {
+ val vector = new PrimitiveVector[Long]
+ for (i <- 0 until 4097) {
+ vector += i
+ }
+ assert(vector.size === 4097)
+ assert(vector.capacity === 8192)
+ vector.trim()
+ assert(vector.size === 4097)
+ assert(vector.capacity === 4097)
+ vector.resize(5000)
+ assert(vector.size === 4097)
+ assert(vector.capacity === 5000)
+ vector.resize(4000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 4000)
+ vector.resize(5000)
+ assert(vector.size === 4000)
+ assert(vector.capacity === 5000)
+ for (i <- 0 until 4000) {
+ assert(vector(i) == i)
+ }
+ intercept[IllegalArgumentException] {
+ vector(4000)
+ }
+ }
+}
diff --git a/docker/spark-test/README.md b/docker/spark-test/README.md
index addea277aa..ec0baf6e6d 100644
--- a/docker/spark-test/README.md
+++ b/docker/spark-test/README.md
@@ -1,10 +1,11 @@
Spark Docker files usable for testing and development purposes.
These images are intended to be run like so:
-docker run -v $SPARK_HOME:/opt/spark spark-test-master
-docker run -v $SPARK_HOME:/opt/spark spark-test-worker <master_ip>
+
+ docker run -v $SPARK_HOME:/opt/spark spark-test-master
+ docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://<master_ip>:7077
Using this configuration, the containers will have their Spark directories
-mounted to your actual SPARK_HOME, allowing you to modify and recompile
+mounted to your actual `SPARK_HOME`, allowing you to modify and recompile
your Spark source and have them immediately usable in the docker images
(without rebuilding them).
diff --git a/docs/_config.yml b/docs/_config.yml
index 48ecb8d0c9..02067f9750 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -5,6 +5,6 @@ markdown: kramdown
# of Spark, Scala, and Mesos.
SPARK_VERSION: 0.9.0-incubating-SNAPSHOT
SPARK_VERSION_SHORT: 0.9.0-SNAPSHOT
-SCALA_VERSION: 2.9.3
+SCALA_VERSION: 2.10
MESOS_VERSION: 0.13.0
SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 0c1d657cde..ad7969d012 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -74,12 +74,12 @@
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a>
<ul class="dropdown-menu">
- <li><a href="api/core/index.html">Spark Core for Java/Scala</a></li>
+ <li><a href="api/core/index.html#org.apache.spark.package">Spark Core for Java/Scala</a></li>
<li><a href="api/pyspark/index.html">Spark Core for Python</a></li>
<li class="divider"></li>
- <li><a href="api/streaming/index.html">Spark Streaming</a></li>
- <li><a href="api/mllib/index.html">MLlib (Machine Learning)</a></li>
- <li><a href="api/bagel/index.html">Bagel (Pregel on Spark)</a></li>
+ <li><a href="api/streaming/index.html#org.apache.spark.streaming.package">Spark Streaming</a></li>
+ <li><a href="api/mllib/index.html#org.apache.spark.mllib.package">MLlib (Machine Learning)</a></li>
+ <li><a href="api/bagel/index.html#org.apache.spark.bagel.package">Bagel (Pregel on Spark)</a></li>
</ul>
</li>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index c574ea7f5c..431de909cb 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -35,7 +35,7 @@ if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1')
# Copy over the scaladoc from each project into the docs directory.
# This directory will be copied over to _site when `jekyll` command is run.
projects.each do |project_name|
- source = "../" + project_name + "/target/scala-2.9.3/api"
+ source = "../" + project_name + "/target/scala-2.10/api"
dest = "api/" + project_name
puts "echo making directory " + dest
diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md
index 140190a38c..de001e6c52 100644
--- a/docs/bagel-programming-guide.md
+++ b/docs/bagel-programming-guide.md
@@ -106,7 +106,7 @@ _Example_
## Operations
-Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/spark/bagel/Bagel.scala) for details.
+Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/incubator-spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details.
### Actions
diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md
index 19c01e179f..c709001632 100644
--- a/docs/building-with-maven.md
+++ b/docs/building-with-maven.md
@@ -45,6 +45,12 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions with
# Cloudera CDH 4.2.0 with MapReduce v2
$ mvn -Phadoop2-yarn -Dhadoop.version=2.0.0-cdh4.2.0 -Dyarn.version=2.0.0-chd4.2.0 -DskipTests clean package
+Hadoop versions 2.2.x and newer can be built by setting the ```new-yarn``` and the ```yarn.version``` as follows:
+
+ # Apache Hadoop 2.2.X and newer
+ $ mvn -Dyarn.version=2.2.0 -Dhadoop.version=2.2.0 -Pnew-yarn
+
+The build process handles Hadoop 2.2.x as a special case that uses the directory ```new-yarn```, which supports the new YARN API. Furthermore, for this version, the build depends on artifacts published by the spark-project to enable Akka 2.0.5 to work with protobuf 2.5.
## Spark Tests in Maven ##
diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md
index f679cad713..e16703292c 100644
--- a/docs/cluster-overview.md
+++ b/docs/cluster-overview.md
@@ -13,7 +13,7 @@ object in your main program (called the _driver program_).
Specifically, to run on a cluster, the SparkContext can connect to several types of _cluster managers_
(either Spark's own standalone cluster manager or Mesos/YARN), which allocate resources across
applications. Once connected, Spark acquires *executors* on nodes in the cluster, which are
-worker processes that run computations and store data for your application.
+worker processes that run computations and store data for your application.
Next, it sends your application code (defined by JAR or Python files passed to SparkContext) to
the executors. Finally, SparkContext sends *tasks* for the executors to run.
@@ -45,7 +45,7 @@ The system currently supports three cluster managers:
easy to set up a cluster.
* [Apache Mesos](running-on-mesos.html) -- a general cluster manager that can also run Hadoop MapReduce
and service applications.
-* [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2.0.
+* [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2.
In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone
cluster on Amazon EC2.
@@ -57,6 +57,18 @@ which takes a list of JAR files (Java/Scala) or .egg and .zip libraries (Python)
worker nodes. You can also dynamically add new files to be sent to executors with `SparkContext.addJar`
and `addFile`.
+## URIs for addJar / addFile
+
+- **file:** - Absolute paths and `file:/` URIs are served by the driver's HTTP file server, and every executor
+ pulls the file from the driver HTTP server
+- **hdfs:**, **http:**, **https:**, **ftp:** - these pull down files and JARs from the URI as expected
+- **local:** - a URI starting with local:/ is expected to exist as a local file on each worker node. This
+ means that no network IO will be incurred, and works well for large files/JARs that are pushed to each worker,
+ or shared via NFS, GlusterFS, etc.
+
+Note that JARs and files are copied to the working directory for each SparkContext on the executor nodes.
+Over time this can use up a significant amount of space and will need to be cleaned up.
+
# Monitoring
Each driver program has a web UI, typically on port 4040, that displays information about running
diff --git a/docs/configuration.md b/docs/configuration.md
index 7940d41a27..677d182e50 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -149,7 +149,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.io.compression.codec</td>
<td>org.apache.spark.io.<br />LZFCompressionCodec</td>
<td>
- The compression codec class to use for various compressions. By default, Spark provides two
+ The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, Spark provides two
codecs: <code>org.apache.spark.io.LZFCompressionCodec</code> and <code>org.apache.spark.io.SnappyCompressionCodec</code>.
</td>
</tr>
@@ -275,12 +275,33 @@ Apart from these, the following properties are also available, and may be useful
</tr>
<tr>
<td>spark.akka.timeout</td>
- <td>20</td>
+ <td>100</td>
<td>
Communication timeout between Spark nodes, in seconds.
</td>
</tr>
<tr>
+ <td>spark.akka.heartbeat.pauses</td>
+ <td>600</td>
+ <td>
+ This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold` if you need to.
+ </td>
+</tr>
+<tr>
+ <td>spark.akka.failure-detector.threshold</td>
+ <td>300.0</td>
+ <td>
+ This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). This maps to akka's `akka.remote.transport-failure-detector.threshold`. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
+ </td>
+</tr>
+<tr>
+ <td>spark.akka.heartbeat.interval</td>
+ <td>1000</td>
+ <td>
+ This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those.
+ </td>
+</tr>
+<tr>
<td>spark.driver.host</td>
<td>(local hostname)</td>
<td>
@@ -319,7 +340,49 @@ Apart from these, the following properties are also available, and may be useful
Should be greater than or equal to 1. Number of allowed retries = this value - 1.
</td>
</tr>
-
+<tr>
+ <td>spark.broadcast.blockSize</td>
+ <td>4096</td>
+ <td>
+ Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>.
+ Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit.
+ </td>
+</tr>
+<tr>
+ <td>spark.shuffle.consolidateFiles</td>
+ <td>false</td>
+ <td>
+ If set to "true", consolidates intermediate files created during a shuffle. Creating fewer files can improve filesystem performance for shuffles with large numbers of reduce tasks. It is recommended to set this to "true" when using ext4 or xfs filesystems. On ext3, this option might degrade performance on machines with many (>8) cores due to filesystem limitations.
+ </td>
+</tr>
+<tr>
+ <td>spark.speculation</td>
+ <td>false</td>
+ <td>
+ If set to "true", performs speculative execution of tasks. This means if one or more tasks are running slowly in a stage, they will be re-launched.
+ </td>
+</tr>
+<tr>
+ <td>spark.speculation.interval</td>
+ <td>100</td>
+ <td>
+ How often Spark will check for tasks to speculate, in milliseconds.
+ </td>
+</tr>
+<tr>
+ <td>spark.speculation.quantile</td>
+ <td>0.75</td>
+ <td>
+ Percentage of tasks which must be complete before speculation is enabled for a particular stage.
+ </td>
+</tr>
+<tr>
+ <td>spark.speculation.multiplier</td>
+ <td>1.5</td>
+ <td>
+ How many times slower a task is than the median to be considered for speculation.
+ </td>
+</tr>
</table>
# Environment Variables
diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md
index 1e5575d657..156a727026 100644
--- a/docs/ec2-scripts.md
+++ b/docs/ec2-scripts.md
@@ -98,7 +98,7 @@ permissions on your private key file, you can run `launch` with the
`bin/hadoop` script in that directory. Note that the data in this
HDFS goes away when you stop and restart a machine.
- There is also a *persistent HDFS* instance in
- `/root/presistent-hdfs` that will keep data across cluster restarts.
+ `/root/persistent-hdfs` that will keep data across cluster restarts.
Typically each node has relatively little space of persistent data
(about 3 GB), but you can use the `--ebs-vol-size` option to
`spark-ec2` to attach a persistent EBS volume to each node for
diff --git a/docs/hadoop-third-party-distributions.md b/docs/hadoop-third-party-distributions.md
index f706625fe9..de6a2b0a43 100644
--- a/docs/hadoop-third-party-distributions.md
+++ b/docs/hadoop-third-party-distributions.md
@@ -10,7 +10,7 @@ with these distributions:
# Compile-time Hadoop Version
When compiling Spark, you'll need to
-[set the SPARK_HADOOP_VERSION flag](http://localhost:4000/index.html#a-note-about-hadoop-versions):
+[set the SPARK_HADOOP_VERSION flag](index.html#a-note-about-hadoop-versions):
SPARK_HADOOP_VERSION=1.0.4 sbt/sbt assembly
@@ -25,8 +25,8 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors.
<h3>CDH Releases</h3>
<table class="table" style="width:350px; margin-right: 20px;">
<tr><th>Release</th><th>Version code</th></tr>
- <tr><td>CDH 4.X.X (YARN mode)</td><td>2.0.0-chd4.X.X</td></tr>
- <tr><td>CDH 4.X.X</td><td>2.0.0-mr1-chd4.X.X</td></tr>
+ <tr><td>CDH 4.X.X (YARN mode)</td><td>2.0.0-cdh4.X.X</td></tr>
+ <tr><td>CDH 4.X.X</td><td>2.0.0-mr1-cdh4.X.X</td></tr>
<tr><td>CDH 3u6</td><td>0.20.2-cdh3u6</td></tr>
<tr><td>CDH 3u5</td><td>0.20.2-cdh3u5</td></tr>
<tr><td>CDH 3u4</td><td>0.20.2-cdh3u4</td></tr>
@@ -40,6 +40,7 @@ the _exact_ Hadoop version you are running to avoid any compatibility errors.
<tr><td>HDP 1.2</td><td>1.1.2</td></tr>
<tr><td>HDP 1.1</td><td>1.0.3</td></tr>
<tr><td>HDP 1.0</td><td>1.0.3</td></tr>
+ <tr><td>HDP 2.0</td><td>2.2.0</td></tr>
</table>
</td>
</tr>
diff --git a/docs/index.md b/docs/index.md
index bd386a8a8f..d3ac696d1e 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -56,14 +56,16 @@ Hadoop, you must build Spark against the same version that your cluster uses.
By default, Spark links to Hadoop 1.0.4. You can change this by setting the
`SPARK_HADOOP_VERSION` variable when compiling:
- SPARK_HADOOP_VERSION=1.2.1 sbt/sbt assembly
+ SPARK_HADOOP_VERSION=2.2.0 sbt/sbt assembly
-In addition, if you wish to run Spark on [YARN](running-on-yarn.md), set
+In addition, if you wish to run Spark on [YARN](running-on-yarn.html), set
`SPARK_YARN` to `true`:
SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly
-(Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`.)
+Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`.
+
+For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to build Spark and publish it locally. See [Launching Spark on YARN](running-on-yarn.html). This is needed because Hadoop 2.2 has non backwards compatible API changes.
# Where to Go from Here
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index d304c5497b..dbcb9ae343 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -91,7 +91,7 @@ The fair scheduler also supports grouping jobs into _pools_, and setting differe
(e.g. weight) for each pool. This can be useful to create a "high-priority" pool for more important jobs,
for example, or to group the jobs of each user together and give _users_ equal shares regardless of how
many concurrent jobs they have instead of giving _jobs_ equal shares. This approach is modeled after the
-[Hadoop Fair Scheduler](http://hadoop.apache.org/docs/stable/fair_scheduler.html).
+[Hadoop Fair Scheduler](http://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/FairScheduler.html).
Without any intervention, newly submitted jobs go into a _default pool_, but jobs' pools can be set by
adding the `spark.scheduler.pool` "local property" to the SparkContext in the thread that's submitting them.
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 5f456b999b..5ed0474477 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -50,6 +50,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the
* `GangliaSink`: Sends metrics to a Ganglia node or multicast group.
* `JmxSink`: Registers metrics for viewing in a JXM console.
* `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data.
+* `GraphiteSink`: Sends metrics to a Graphite node.
The syntax of the metrics configuration file is defined in an example configuration file,
`$SPARK_HOME/conf/metrics.conf.template`.
diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md
index 6c2336ad0c..55e39b1de1 100644
--- a/docs/python-programming-guide.md
+++ b/docs/python-programming-guide.md
@@ -131,6 +131,17 @@ sc = SparkContext("local", "App Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg
Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
+You can set [system properties](configuration.html#system-properties)
+using `SparkContext.setSystemProperty()` class method *before*
+instantiating SparkContext. For example, to set the amount of memory
+per executor process:
+
+{% highlight python %}
+from pyspark import SparkContext
+SparkContext.setSystemProperty('spark.executor.memory', '2g')
+sc = SparkContext("local", "App Name")
+{% endhighlight %}
+
# API Docs
[API documentation](api/pyspark/index.html) for PySpark is available as Epydoc.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 2898af0bed..aa75ca4324 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -17,10 +17,12 @@ This can be built by setting the Hadoop version and `SPARK_YARN` environment var
The assembled JAR will be something like this:
`./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly_{{site.SPARK_VERSION}}-hadoop2.0.5.jar`.
+The build process now also supports new YARN versions (2.2.x). See below.
# Preparations
- Building a YARN-enabled assembly (see above).
+- The assembled jar can be installed into HDFS or used locally.
- Your application code must be packaged into a separate JAR file.
If you want to test out the YARN deployment mode, you can use the current Spark examples. A `spark-examples_{{site.SCALA_VERSION}}-{{site.SPARK_VERSION}}` file can be generated by running `sbt/sbt assembly`. NOTE: since the documentation you're reading is for Spark version {{site.SPARK_VERSION}}, we are assuming here that you have downloaded Spark {{site.SPARK_VERSION}} or checked it out of source control. If you are using a different version of Spark, the version numbers in the jar generated by the sbt package command will obviously be different.
@@ -30,18 +32,26 @@ If you want to test out the YARN deployment mode, you can use the current Spark
Most of the configs are the same for Spark on YARN as other deploys. See the Configuration page for more information on those. These are configs that are specific to SPARK on YARN.
Environment variables:
+
* `SPARK_YARN_USER_ENV`, to add environment variables to the Spark processes launched on YARN. This can be a comma separated list of environment variables, e.g. `SPARK_YARN_USER_ENV="JAVA_HOME=/jdk64,FOO=bar"`.
System Properties:
-* 'spark.yarn.applicationMaster.waitTries', property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10.
-* 'spark.yarn.submit.file.replication', the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives.
-* 'spark.yarn.preserve.staging.files', set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them.
+
+* `spark.yarn.applicationMaster.waitTries`, property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10.
+* `spark.yarn.submit.file.replication`, the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives.
+* `spark.yarn.preserve.staging.files`, set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them.
+* `spark.yarn.scheduler.heartbeat.interval-ms`, the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds.
+* `spark.yarn.max.worker.failures`, the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3.
# Launching Spark on YARN
Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster.
This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager.
+There are two scheduler mode that can be used to launch spark application on YARN.
+
+## Launch spark application by YARN Client with yarn-standalone mode.
+
The command to launch the YARN Client is as follows:
SPARK_JAR=<SPARK_ASSEMBLY_JAR_FILE> ./spark-class org.apache.spark.deploy.yarn.Client \
@@ -49,6 +59,7 @@ The command to launch the YARN Client is as follows:
--class <APP_MAIN_CLASS> \
--args <APP_MAIN_ARGUMENTS> \
--num-workers <NUMBER_OF_WORKER_MACHINES> \
+ --master-class <ApplicationMaster_CLASS>
--master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
@@ -82,12 +93,37 @@ For example:
$ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout
Pi is roughly 3.13794
-The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+
+With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell.
+
+## Launch spark application with yarn-client mode.
+
+With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
+
+In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh
+
+For example:
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ ./run-example org.apache.spark.examples.SparkPi yarn-client
+
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ MASTER=yarn-client ./spark-shell
+
+# Building Spark for Hadoop/YARN 2.2.x
+
+Hadoop 2.2.x users must build Spark and publish it locally. The SBT build process handles Hadoop 2.2.x as a special case. This version of Hadoop has new YARN API changes and depends on a Protobuf version (2.5) that is not compatible with the Akka version (2.0.5) that Spark uses. Therefore, if the Hadoop version (e.g. set through ```SPARK_HADOOP_VERSION```) starts with 2.2.0 or higher then the build process will depend on Akka artifacts distributed by the Spark project compatible with Protobuf 2.5. Furthermore, the build process then uses the directory ```new-yarn``` (instead of ```yarn```), which supports the new YARN API. The build process should seamlessly work out of the box.
+
+See [Building Spark with Maven](building-with-maven.html) for instructions on how to build Spark using the Maven process.
# Important Notes
-- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
- We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
- The --addJars option allows the SparkContext.addJar function to work if you are using it with local files. It does not need to be used if you are using it with HDFS, HTTP, HTTPS, or FTP files.
+- YARN 2.2.x users cannot simply depend on the Spark packages without building Spark, as the published Spark artifacts are compiled to work with the pre 2.2 API. Those users must build Spark and publish it locally.
diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md
index 03647a2ad2..56d2a3a4a0 100644
--- a/docs/scala-programming-guide.md
+++ b/docs/scala-programming-guide.md
@@ -142,7 +142,7 @@ All transformations in Spark are <i>lazy</i>, in that they do not compute their
By default, each transformed RDD is recomputed each time you run an action on it. However, you may also *persist* an RDD in memory using the `persist` (or `cache`) method, in which case Spark will keep the elements around on the cluster for much faster access the next time you query it. There is also support for persisting datasets on disk, or replicated across the cluster. The next section in this document describes these options.
-The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.RDD) for details):
+The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD) for details):
### Transformations
@@ -211,7 +211,7 @@ The following tables list the transformations and actions currently supported (s
</tr>
</table>
-A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
+A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
### Actions
@@ -259,7 +259,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in
</tr>
</table>
-A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
+A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
## RDD Persistence
@@ -363,7 +363,7 @@ res2: Int = 10
# Where to Go from Here
-You can see some [example Spark programs](http://www.spark-project.org/examples.html) on the Spark website.
+You can see some [example Spark programs](http://spark.incubator.apache.org/examples.html) on the Spark website.
In addition, Spark includes several samples in `examples/src/main/scala`. Some of them have both Spark versions and local (non-parallel) versions, allowing you to see what had to be changed to make the program run on a cluster. You can run them using by passing the class name to the `run-example` script included in Spark; for example:
./run-example org.apache.spark.examples.SparkPi
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 17066ef0dd..b822265b5a 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -51,11 +51,11 @@ Finally, the following configuration options can be passed to the master and wor
</tr>
<tr>
<td><code>-c CORES</code>, <code>--cores CORES</code></td>
- <td>Total CPU cores to allow Spark applicatons to use on the machine (default: all available); only on worker</td>
+ <td>Total CPU cores to allow Spark applications to use on the machine (default: all available); only on worker</td>
</tr>
<tr>
<td><code>-m MEM</code>, <code>--memory MEM</code></td>
- <td>Total amount of memory to allow Spark applicatons to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
+ <td>Total amount of memory to allow Spark applications to use on the machine, in a format like 1000M or 2G (default: your machine's total RAM minus 1 GB); only on worker</td>
</tr>
<tr>
<td><code>-d DIR</code>, <code>--work-dir DIR</code></td>
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238..82f42e0b8d 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -73,6 +73,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
Iterator[T] => Iterator[U] when running on an DStream of type T. </td>
</tr>
<tr>
+ <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+ <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
+<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
</tr>
@@ -210,7 +214,7 @@ ssc.stop()
{% endhighlight %}
# Example
-A simple example to start off is the [NetworkWordCount](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/NetworkWordCount.scala` .
+A simple example to start off is the [NetworkWordCount](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `<Spark repo>/streaming/src/main/scala/org/apache/spark/streaming/examples/NetworkWordCount.scala` .
{% highlight scala %}
import org.apache.spark.streaming.{Seconds, StreamingContext}
@@ -279,7 +283,7 @@ Time: 1357008430000 ms
</td>
</table>
-You can find more examples in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run-example org.apache.spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files.
+You can find more examples in `<Spark repo>/streaming/src/main/scala/org/apache/spark/streaming/examples/`. They can be run in the similar manner using `./run-example org.apache.spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files.
# DStream Persistence
Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`.
@@ -479,7 +483,7 @@ Similar to [Spark's Java API](java-programming-guide.html), we also provide a Ja
1. Functions for transformations must be implemented as subclasses of [Function](api/core/index.html#org.apache.spark.api.java.function.Function) and [Function2](api/core/index.html#org.apache.spark.api.java.function.Function2)
1. Unlike the Scala API, the Java API handles DStreams for key-value pairs using a separate [JavaPairDStream](api/streaming/index.html#org.apache.spark.streaming.api.java.JavaPairDStream) class(similar to [JavaRDD and JavaPairRDD](java-programming-guide.html#rdd-classes). DStream functions like `map` and `filter` are implemented separately by JavaDStreams and JavaPairDStream to return DStreams of appropriate types.
-Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `<spark repo>/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java`
+Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `<spark repo>/examples/src/main/java/org/apache/spark/streaming/examples/JavaNetworkWordCount.java`
The streaming context and the socket stream from input source is started by using a `JavaStreamingContext`, that has the same parameters and provides the same input streams as its Scala counterpart.
@@ -523,5 +527,5 @@ JavaPairDStream<String, Integer> wordCounts = words.map(
# Where to Go from Here
* API docs - [Scala](api/streaming/index.html#org.apache.spark.streaming.package) and [Java](api/streaming/index.html#org.apache.spark.streaming.api.java.package)
-* More examples - [Scala](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/spark/streaming/examples) and [Java](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/spark/streaming/examples)
+* More examples - [Scala](https://github.com/apache/incubator-spark/tree/master/examples/src/main/scala/org/apache/spark/streaming/examples) and [Java](https://github.com/apache/incubator-spark/tree/master/examples/src/main/java/org/apache/spark/streaming/examples)
* [Paper describing Spark Streaming](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf)
diff --git a/docs/tuning.md b/docs/tuning.md
index f491ae9b95..a4be188169 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries:
for best performance.
You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")`
-*before* creating your SparkContext. The only reason it is not the default is because of the custom
+*before* creating your SparkContext. This setting configures the serializer used for not only shuffling data between worker
+nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
registration requirement, but we recommend trying it in any network-intensive application.
Finally, to register your classes with Kryo, create a public class that extends
@@ -67,7 +68,7 @@ The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced
registration options, such as adding custom serialization code.
If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb`
-system property. The default is 32, but this value needs to be large enough to hold the *largest*
+system property. The default is 2, but this value needs to be large enough to hold the *largest*
object you will serialize.
Finally, if you don't register your classes, Kryo will still work, but it will have to store the
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 65868b76b9..a2b0e7e7f4 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -72,12 +72,12 @@ def parse_args():
parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
parser.add_option("-v", "--spark-version", default="0.8.0",
help="Version of Spark to use: 'X.Y.Z' or a specific git hash")
- parser.add_option("--spark-git-repo",
- default="https://github.com/mesos/spark",
+ parser.add_option("--spark-git-repo",
+ default="https://github.com/apache/incubator-spark",
help="Github repo from which to checkout supplied commit hash")
parser.add_option("--hadoop-major-version", default="1",
help="Major version of Hadoop (default: 1)")
- parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
+ parser.add_option("-D", metavar="[ADDRESS:]PORT", dest="proxy_port",
help="Use SSH dynamic port forwarding to create a SOCKS proxy at " +
"the given local address (for use with login)")
parser.add_option("--resume", action="store_true", default=False,
@@ -101,6 +101,8 @@ def parse_args():
help="The SSH user you want to connect as (default: root)")
parser.add_option("--delete-groups", action="store_true", default=False,
help="When destroying a cluster, delete the security groups that were created")
+ parser.add_option("--use-existing-master", action="store_true", default=False,
+ help="Launch fresh slaves, but use an existing stopped master if possible")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -191,7 +193,7 @@ def get_spark_ami(opts):
instance_type = "pvm"
print >> stderr,\
"Don't recognize %s, assuming type is pvm" % opts.instance_type
-
+
ami_path = "%s/%s/%s" % (AMI_PREFIX, opts.region, instance_type)
try:
ami = urllib2.urlopen(ami_path).read().strip()
@@ -215,6 +217,7 @@ def launch_cluster(conn, opts, cluster_name):
master_group.authorize(src_group=slave_group)
master_group.authorize('tcp', 22, 22, '0.0.0.0/0')
master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0')
+ master_group.authorize('tcp', 19999, 19999, '0.0.0.0/0')
master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0')
master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0')
master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0')
@@ -232,9 +235,9 @@ def launch_cluster(conn, opts, cluster_name):
slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0')
# Check if instances are already running in our groups
- active_nodes = get_existing_cluster(conn, opts, cluster_name,
- die_on_error=False)
- if any(active_nodes):
+ existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name,
+ die_on_error=False)
+ if existing_slaves or (existing_masters and not opts.use_existing_master):
print >> stderr, ("ERROR: There are already instances running in " +
"group %s or %s" % (master_group.name, slave_group.name))
sys.exit(1)
@@ -335,21 +338,28 @@ def launch_cluster(conn, opts, cluster_name):
zone, slave_res.id)
i += 1
- # Launch masters
- master_type = opts.master_instance_type
- if master_type == "":
- master_type = opts.instance_type
- if opts.zone == 'all':
- opts.zone = random.choice(conn.get_all_zones()).name
- master_res = image.run(key_name = opts.key_pair,
- security_groups = [master_group],
- instance_type = master_type,
- placement = opts.zone,
- min_count = 1,
- max_count = 1,
- block_device_map = block_map)
- master_nodes = master_res.instances
- print "Launched master in %s, regid = %s" % (zone, master_res.id)
+ # Launch or resume masters
+ if existing_masters:
+ print "Starting master..."
+ for inst in existing_masters:
+ if inst.state not in ["shutting-down", "terminated"]:
+ inst.start()
+ master_nodes = existing_masters
+ else:
+ master_type = opts.master_instance_type
+ if master_type == "":
+ master_type = opts.instance_type
+ if opts.zone == 'all':
+ opts.zone = random.choice(conn.get_all_zones()).name
+ master_res = image.run(key_name = opts.key_pair,
+ security_groups = [master_group],
+ instance_type = master_type,
+ placement = opts.zone,
+ min_count = 1,
+ max_count = 1,
+ block_device_map = block_map)
+ master_nodes = master_res.instances
+ print "Launched master in %s, regid = %s" % (zone, master_res.id)
# Return all the instances
return (master_nodes, slave_nodes)
@@ -403,8 +413,8 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
print slave.public_dns_name
ssh_write(slave.public_dns_name, opts, ['tar', 'x'], dot_ssh_tar)
- modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
- 'mapreduce', 'spark-standalone']
+ modules = ['spark', 'shark', 'ephemeral-hdfs', 'persistent-hdfs',
+ 'mapreduce', 'spark-standalone', 'tachyon']
if opts.hadoop_major_version == "1":
modules = filter(lambda x: x != "mapreduce", modules)
@@ -579,7 +589,7 @@ def ssh(host, opts, command):
while True:
try:
return subprocess.check_call(
- ssh_command(opts) + ['-t', '%s@%s' % (opts.user, host), stringify_command(command)])
+ ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), stringify_command(command)])
except subprocess.CalledProcessError as e:
if (tries > 2):
# If this was an ssh failure, provide the user with hints.
@@ -668,12 +678,12 @@ def real_main():
print "Terminating slaves..."
for inst in slave_nodes:
inst.terminate()
-
+
# Delete security groups as well
if opts.delete_groups:
print "Deleting security groups (this will take some time)..."
group_names = [cluster_name + "-master", cluster_name + "-slaves"]
-
+
attempt = 1;
while attempt <= 3:
print "Attempt %d" % attempt
@@ -720,7 +730,7 @@ def real_main():
if opts.proxy_port != None:
proxy_opt = ['-D', opts.proxy_port]
subprocess.check_call(
- ssh_command(opts) + proxy_opt + ['-t', "%s@%s" % (opts.user, master)])
+ ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)])
elif action == "get-master":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
@@ -731,6 +741,7 @@ def real_main():
cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " +
"BUT THE CLUSTER WILL KEEP USING SPACE ON\n" +
"AMAZON EBS IF IT IS EBS-BACKED!!\n" +
+ "All data on spot-instance slaves will be lost.\n" +
"Stop cluster " + cluster_name + " (y/N): ")
if response == "y":
(master_nodes, slave_nodes) = get_existing_cluster(
@@ -742,7 +753,10 @@ def real_main():
print "Stopping slaves..."
for inst in slave_nodes:
if inst.state not in ["shutting-down", "terminated"]:
- inst.stop()
+ if inst.spot_instance_request_id:
+ inst.terminate()
+ else:
+ inst.stop()
elif action == "start":
(master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name)
diff --git a/examples/pom.xml b/examples/pom.xml
index b8c020a321..7a7032c319 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -26,41 +26,48 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-examples_2.9.3</artifactId>
+ <artifactId>spark-examples_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Examples</name>
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
+
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming_2.9.3</artifactId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib_2.9.3</artifactId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel_2.9.3</artifactId>
+ <artifactId>spark-bagel_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
@@ -80,10 +87,19 @@
</exclusions>
</dependency>
<dependency>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <groupId>com.sksamuel.kafka</groupId>
+ <artifactId>kafka_${scala.binary.version}</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
@@ -91,23 +107,23 @@
</dependency>
<dependency>
<groupId>com.twitter</groupId>
- <artifactId>algebird-core_2.9.2</artifactId>
+ <artifactId>algebird-core_${scala.binary.version}</artifactId>
<version>0.1.11</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.cassandra</groupId>
<artifactId>cassandra-all</artifactId>
- <version>1.2.5</version>
+ <version>1.2.6</version>
<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
@@ -137,13 +153,21 @@
<groupId>org.apache.cassandra.deps</groupId>
<artifactId>avro</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
index 152f029213..407cd7ccfa 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java
@@ -123,7 +123,7 @@ public class JavaLogQuery {
});
List<Tuple2<Tuple3<String, String, String>, Stats>> output = counts.collect();
- for (Tuple2 t : output) {
+ for (Tuple2<?,?> t : output) {
System.out.println(t._1 + "\t" + t._2);
}
System.exit(0);
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
index c5603a639b..89aed8f279 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
@@ -21,7 +21,6 @@ import scala.Tuple2;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -106,7 +105,7 @@ public class JavaPageRank {
// Collects all URL ranks and dump them to console.
List<Tuple2<String, Double>> output = ranks.collect();
- for (Tuple2 tuple : output) {
+ for (Tuple2<?,?> tuple : output) {
System.out.println(tuple._1 + " has rank: " + tuple._2 + ".");
}
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
index 07d32ad659..bd6383e13d 100644
--- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java
@@ -58,7 +58,7 @@ public class JavaWordCount {
});
List<Tuple2<String, Integer>> output = counts.collect();
- for (Tuple2 tuple : output) {
+ for (Tuple2<?,?> tuple : output) {
System.out.println(tuple._1 + ": " + tuple._2);
}
System.exit(0);
diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java
index 628cb892b6..45a0d237da 100644
--- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java
+++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
-import java.io.Serializable;
import java.util.Arrays;
import java.util.StringTokenizer;
diff --git a/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
new file mode 100644
index 0000000000..9a8e4209ed
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/streaming/examples/JavaKafkaWordCount.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples;
+
+import java.util.Map;
+import java.util.HashMap;
+
+import com.google.common.collect.Lists;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import scala.Tuple2;
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: JavaKafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
+ * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
+ * <zkQuorum> is a list of one or more zookeeper servers that make quorum
+ * <group> is the name of kafka consumer group
+ * <topics> is a list of one or more kafka topics to consume from
+ * <numThreads> is the number of threads the kafka consumer should use
+ *
+ * Example:
+ * `./run-example org.apache.spark.streaming.examples.JavaKafkaWordCount local[2] zoo01,zoo02,
+ * zoo03 my-consumer-group topic1,topic2 1`
+ */
+
+public class JavaKafkaWordCount {
+ public static void main(String[] args) {
+ if (args.length < 5) {
+ System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>");
+ System.exit(1);
+ }
+
+ // Create the context with a 1 second batch size
+ JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
+ new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+
+ int numThreads = Integer.parseInt(args[4]);
+ Map<String, Integer> topicMap = new HashMap<String, Integer>();
+ String[] topics = args[3].split(",");
+ for (String topic: topics) {
+ topicMap.put(topic, numThreads);
+ }
+
+ JavaPairDStream<String, String> messages = ssc.kafkaStream(args[1], args[2], topicMap);
+
+ JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(Tuple2<String, String> tuple2) throws Exception {
+ return tuple2._2();
+ }
+ });
+
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(x.split(" "));
+ }
+ });
+
+ JavaPairDStream<String, Integer> wordCounts = words.map(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) throws Exception {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) throws Exception {
+ return i1 + i2;
+ }
+ });
+
+ wordCounts.print();
+ ssc.start();
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 868ff81f67..a119980992 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -22,27 +22,36 @@ import org.apache.spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+ System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]")
System.exit(1)
}
+ val bcName = if (args.length > 3) args(3) else "Http"
+ val blockSize = if (args.length > 4) args(4) else "4096"
+
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
+ System.setProperty("spark.broadcast.blockSize", blockSize)
+
val sc = new SparkContext(args(0), "Broadcast Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
- var arr1 = new Array[Int](num)
+ val arr1 = new Array[Int](num)
for (i <- 0 until arr1.length) {
arr1(i) = i
}
- for (i <- 0 until 2) {
+ for (i <- 0 until 3) {
println("Iteration " + i)
println("===========")
+ val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
- sc.parallelize(1 to 10, slices).foreach {
- i => println(barr1.value.size)
- }
+ val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size)
+ // Collect the small RDD so we can print the observed sizes locally.
+ observedSizes.collect().foreach(i => println(i))
+ println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
System.exit(0)
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index 4af45b2b4a..83db8b9e26 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -120,7 +120,7 @@ object LocalALS {
System.exit(1)
}
}
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS);
+ printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
val R = generateR()
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
index f79f0142b8..e1afc29f9a 100644
--- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala
@@ -18,35 +18,38 @@
package org.apache.spark.examples
import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
object MultiBroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+ System.err.println("Usage: MultiBroadcastTest <master> [<slices>] [numElem]")
System.exit(1)
}
- val sc = new SparkContext(args(0), "Broadcast Test",
+ val sc = new SparkContext(args(0), "Multi-Broadcast Test",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
- var arr1 = new Array[Int](num)
+ val arr1 = new Array[Int](num)
for (i <- 0 until arr1.length) {
arr1(i) = i
}
- var arr2 = new Array[Int](num)
+ val arr2 = new Array[Int](num)
for (i <- 0 until arr2.length) {
arr2(i) = i
}
val barr1 = sc.broadcast(arr1)
val barr2 = sc.broadcast(arr2)
- sc.parallelize(1 to 10, slices).foreach {
- i => println(barr1.value.size + barr2.value.size)
+ val observedSizes: RDD[(Int, Int)] = sc.parallelize(1 to 10, slices).map { _ =>
+ (barr1.value.size, barr2.value.size)
}
+ // Collect the small RDD so we can print the observed sizes locally.
+ observedSizes.collect().foreach(i => println(i))
System.exit(0)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index 646682878f..86dd9ca1b3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -21,6 +21,7 @@ import java.util.Random
import scala.math.exp
import org.apache.spark.util.Vector
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.InputFormatInfo
/**
@@ -51,7 +52,7 @@ object SparkHdfsLR {
System.exit(1)
}
val inputPath = args(1)
- val conf = SparkEnv.get.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val sc = new SparkContext(args(0), "SparkHdfsLR",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")), Map(),
InputFormatInfo.computePreferredLocations(
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
index 5a2bc9b0d0..a689e5a360 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala
@@ -38,6 +38,6 @@ object SparkPi {
if (x*x + y*y < 1) 1 else 0
}.reduce(_ + _)
println("Pi is roughly " + 4.0 * count / n)
- System.exit(0)
+ spark.stop()
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
index 5a7a9d1bd8..8543ce0e32 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala
@@ -65,7 +65,7 @@ object SparkTC {
oldCount = nextCount
// Perform the join, obtaining an RDD of (y, (z, x)) pairs,
// then project the result to obtain the new (x, z) paths.
- tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache();
+ tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache()
nextCount = tc.count()
} while (nextCount != oldCount)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index cd3423a07b..50e3f9639c 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.examples
import scala.collection.mutable.LinkedList
import scala.util.Random
+import scala.reflect.ClassTag
import akka.actor.Actor
import akka.actor.ActorRef
@@ -82,10 +83,10 @@ class FeederActor extends Actor {
*
* @see [[org.apache.spark.streaming.examples.FeederActor]]
*/
-class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String)
+class SampleActorReceiver[T: ClassTag](urlOfPublisher: String)
extends Actor with Receiver {
- lazy private val remotePublisher = context.actorFor(urlOfPublisher)
+ lazy private val remotePublisher = context.actorSelection(urlOfPublisher)
override def preStart = remotePublisher ! SubscribeReceiver(context.self)
@@ -120,7 +121,7 @@ object FeederActor {
println("Feeder started as:" + feeder)
- actorSystem.awaitTermination();
+ actorSystem.awaitTermination()
}
}
@@ -164,7 +165,7 @@ object ActorWordCount {
*/
val lines = ssc.actorStream[String](
- Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format(
+ Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format(
host, port.toInt))), "SampleReceiver")
//compute wordcount
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
index 12f939d5a7..570ba4c81a 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala
@@ -18,13 +18,11 @@
package org.apache.spark.streaming.examples
import java.util.Properties
-import kafka.message.Message
-import kafka.producer.SyncProducerConfig
+
import kafka.producer._
-import org.apache.spark.SparkContext
+
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.RawTextHelper._
/**
@@ -54,9 +52,10 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
- val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
+ val lines = ssc.kafkaStream(zkQuorum, group, topicpMap).map(_._2)
val words = lines.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
+ val wordCounts = words.map(x => (x, 1l))
+ .reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
ssc.start()
@@ -68,15 +67,16 @@ object KafkaWordCountProducer {
def main(args: Array[String]) {
if (args.length < 2) {
- System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
+ System.err.println("Usage: KafkaWordCountProducer <metadataBrokerList> <topic> " +
+ "<messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
- val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
+ val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
// Zookeper connection properties
val props = new Properties()
- props.put("zk.connect", zkQuorum)
+ props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
@@ -85,11 +85,13 @@ object KafkaWordCountProducer {
// Send some messages
while(true) {
val messages = (1 to messagesPerSec.toInt).map { messageNum =>
- (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
+ val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
+ .mkString(" ")
+
+ new KeyedMessage[String, String](topic, str)
}.toArray
- println(messages.mkString(","))
- val data = new ProducerData[String, String](topic, messages)
- producer.send(data)
+
+ producer.send(messages: _*)
Thread.sleep(100)
}
}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
new file mode 100644
index 0000000000..ff332a0282
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.examples
+
+import org.apache.spark.streaming.{ Seconds, StreamingContext }
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.MQTTReceiver
+import org.apache.spark.storage.StorageLevel
+
+import org.eclipse.paho.client.mqttv3.MqttClient
+import org.eclipse.paho.client.mqttv3.MqttClientPersistence
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.eclipse.paho.client.mqttv3.MqttMessage
+import org.eclipse.paho.client.mqttv3.MqttTopic
+
+/**
+ * A simple Mqtt publisher for demonstration purposes, repeatedly publishes
+ * Space separated String Message "hello mqtt demo for spark streaming"
+ */
+object MQTTPublisher {
+
+ var client: MqttClient = _
+
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
+ System.exit(1)
+ }
+
+ val Seq(brokerUrl, topic) = args.toSeq
+
+ try {
+ var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp")
+ client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance)
+ } catch {
+ case e: MqttException => println("Exception Caught: " + e)
+ }
+
+ client.connect()
+
+ val msgtopic: MqttTopic = client.getTopic(topic)
+ val msg: String = "hello mqtt demo for spark streaming"
+
+ while (true) {
+ val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes())
+ msgtopic.publish(message)
+ println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
+ }
+ client.disconnect()
+ }
+}
+
+/**
+ * A sample wordcount with MqttStream stream
+ *
+ * To work with Mqtt, Mqtt Message broker/server required.
+ * Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker
+ * In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto`
+ * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/
+ * Example Java code for Mqtt Publisher and Subscriber can be found here https://bitbucket.org/mkjinesh/mqttclient
+ * Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
+ * In local mode, <master> should be 'local[n]' with n > 1
+ * <MqttbrokerUrl> and <topic> describe where Mqtt publisher is running.
+ *
+ * To run this example locally, you may run publisher as
+ * `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
+ * and run the example as
+ * `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
+ */
+object MQTTWordCount {
+
+ def main(args: Array[String]) {
+ if (args.length < 3) {
+ System.err.println(
+ "Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
+ " In local mode, <master> should be 'local[n]' with n > 1")
+ System.exit(1)
+ }
+
+ val Seq(master, brokerUrl, topic) = args.toSeq
+
+ val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
+ Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+ val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)
+
+ val words = lines.flatMap(x => x.toString.split(" "))
+ val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+ wordCounts.print()
+ ssc.start()
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
index c8743b9e25..e83ce78aa5 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala
@@ -23,6 +23,7 @@ import akka.zeromq._
import org.apache.spark.streaming.{ Seconds, StreamingContext }
import org.apache.spark.streaming.StreamingContext._
import akka.zeromq.Subscribe
+import akka.util.ByteString
/**
* A simple publisher for demonstration purposes, repeatedly publishes random Messages
@@ -40,10 +41,11 @@ object SimpleZeroMQPublisher {
val acs: ActorSystem = ActorSystem()
val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url))
- val messages: Array[String] = Array("words ", "may ", "count ")
+ implicit def stringToByteString(x: String) = ByteString(x)
+ val messages: List[ByteString] = List("words ", "may ", "count ")
while (true) {
Thread.sleep(1000)
- pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList)
+ pubSocket ! ZMQMessage(ByteString(topic) :: messages)
}
acs.awaitTermination()
}
@@ -78,7 +80,7 @@ object ZeroMQWordCount {
val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2),
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
- def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator
+ def bytesToStringIterator(x: Seq[ByteString]) = (x.map(_.utf8String)).iterator
//For this stream, a zeroMQ publisher should be running.
val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator)
diff --git a/mllib/pom.xml b/mllib/pom.xml
index f472082ad1..dda3900afe 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib_2.9.3</artifactId>
+ <artifactId>spark-mllib_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project ML Library</name>
<url>http://spark.incubator.apache.org/</url>
@@ -34,7 +34,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -48,12 +48,12 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@@ -63,8 +63,8 @@
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index edbf77dbcc..0dee9399a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -18,15 +18,16 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
+
+import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.XORShiftRandom
-import org.jblas.DoubleMatrix
/**
@@ -195,7 +196,7 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
@@ -210,7 +211,7 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
- val seed = new Random().nextInt()
+ val seed = new XORShiftRandom().nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
@@ -222,7 +223,7 @@ class KMeans private (
for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
- val rand = new Random(seed ^ (step << 16) ^ index)
+ val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
for {
p <- points
r <- 0 until runs
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
index 5aec867257..d5f3f6b8db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala
@@ -83,7 +83,7 @@ object MFDataGenerator{
scala.math.round(.99 * m * n)).toInt
val rand = new Random()
val mn = m * n
- val shuffled = rand.shuffle(1 to mn toIterable)
+ val shuffled = rand.shuffle(1 to mn toList)
val omega = shuffled.slice(0, sampSize)
val ordered = omega.sortWith(_ < _).toArray
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 32d3934ac1..33b99f4bd3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -77,7 +77,7 @@ public class JavaKMeansSuite implements Serializable {
@Test
public void runKMeansUsingStaticMethods() {
- List<double[]> points = new ArrayList();
+ List<double[]> points = new ArrayList<double[]>();
points.add(new double[]{1.0, 2.0, 6.0});
points.add(new double[]{1.0, 3.0, 0.0});
points.add(new double[]{1.0, 4.0, 6.0});
@@ -94,7 +94,7 @@ public class JavaKMeansSuite implements Serializable {
@Test
public void runKMeansUsingConstructor() {
- List<double[]> points = new ArrayList();
+ List<double[]> points = new ArrayList<double[]>();
points.add(new double[]{1.0, 2.0, 6.0});
points.add(new double[]{1.0, 3.0, 0.0});
points.add(new double[]{1.0, 4.0, 6.0});
diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
index eafee060cd..b40f552e0d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java
@@ -21,8 +21,6 @@ import java.io.Serializable;
import java.util.List;
import java.lang.Math;
-import scala.Tuple2;
-
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
diff --git a/new-yarn/pom.xml b/new-yarn/pom.xml
new file mode 100644
index 0000000000..4cd28f34e3
--- /dev/null
+++ b/new-yarn/pom.xml
@@ -0,0 +1,161 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ ~ Licensed to the Apache Software Foundation (ASF) under one or more
+ ~ contributor license agreements. See the NOTICE file distributed with
+ ~ this work for additional information regarding copyright ownership.
+ ~ The ASF licenses this file to You under the Apache License, Version 2.0
+ ~ (the "License"); you may not use this file except in compliance with
+ ~ the License. You may obtain a copy of the License at
+ ~
+ ~ http://www.apache.org/licenses/LICENSE-2.0
+ ~
+ ~ Unless required by applicable law or agreed to in writing, software
+ ~ distributed under the License is distributed on an "AS IS" BASIS,
+ ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ ~ See the License for the specific language governing permissions and
+ ~ limitations under the License.
+ -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-parent</artifactId>
+ <version>0.9.0-incubating-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-yarn_2.10</artifactId>
+ <packaging>jar</packaging>
+ <name>Spark Project YARN Support</name>
+ <url>http://spark.incubator.apache.org/</url>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_2.10</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-common</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-client</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ <version>${yarn.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.avro</groupId>
+ <artifactId>avro-ipc</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_2.10</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <configuration>
+ <shadedArtifactAttached>false</shadedArtifactAttached>
+ <outputFile>${project.build.directory}/${project.artifactId}-${project.version}-shaded.jar</outputFile>
+ <artifactSet>
+ <includes>
+ <include>*:*</include>
+ </includes>
+ </artifactSet>
+ <filters>
+ <filter>
+ <artifact>*:*</artifact>
+ <excludes>
+ <exclude>META-INF/*.SF</exclude>
+ <exclude>META-INF/*.DSA</exclude>
+ <exclude>META-INF/*.RSA</exclude>
+ </excludes>
+ </filter>
+ </filters>
+ </configuration>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <transformers>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
+ <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
+ <resource>reference.conf</resource>
+ </transformer>
+ </transformers>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>test</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <exportAntProperties>true</exportAntProperties>
+ <tasks>
+ <property name="spark.classpath" refid="maven.test.classpath" />
+ <property environment="env" />
+ <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
+ <condition>
+ <not>
+ <or>
+ <isset property="env.SCALA_HOME" />
+ <isset property="env.SCALA_LIBRARY_PATH" />
+ </or>
+ </not>
+ </condition>
+ </fail>
+ </tasks>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <environmentVariables>
+ <SPARK_HOME>${basedir}/..</SPARK_HOME>
+ <SPARK_TESTING>1</SPARK_TESTING>
+ <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
+ </environmentVariables>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+</project>
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
new file mode 100644
index 0000000000..eeeca3ea8a
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -0,0 +1,446 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.IOException
+import java.net.Socket
+import java.util.concurrent.CopyOnWriteArrayList
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.util.ShutdownHookManager
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.Utils
+
+
+class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var rpc: YarnRPC = YarnRPC.create(conf)
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+ private var appAttemptId: ApplicationAttemptId = _
+ private var userThread: Thread = _
+ private val fs = FileSystem.get(yarnConf)
+
+ private var yarnAllocator: YarnAllocationHandler = _
+ private var isFinished: Boolean = false
+ private var uiAddress: String = _
+ private val maxAppAttempts: Int = conf.getInt(
+ YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS)
+ private var isLastAMRetry: Boolean = true
+ private var amClient: AMRMClient[ContainerRequest] = _
+
+ // Default to numWorkers * 2, with minimum of 3
+ private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures",
+ math.max(args.numWorkers * 2, 3).toString()).toInt
+
+ def run() {
+ // Setup the directories so things go to YARN approved directories rather
+ // than user specified and /tmp.
+ System.setProperty("spark.local.dir", getLocalDirs())
+
+ // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using.
+ ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
+
+ appAttemptId = getApplicationAttemptId()
+ isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
+ amClient = AMRMClient.createAMRMClient()
+ amClient.init(yarnConf)
+ amClient.start()
+
+ // Workaround until hadoop moves to something which has
+ // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
+ // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
+
+ ApplicationMaster.register(this)
+
+ // Start the user's JAR
+ userThread = startUserClass()
+
+ // This a bit hacky, but we need to wait until the spark.driver.port property has
+ // been set by the Thread executing the user class.
+ waitForSparkMaster()
+
+ waitForSparkContextInitialized()
+
+ // Do this after Spark master is up and SparkContext is created so that we can register UI Url.
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Wait for the user class to Finish
+ userThread.join()
+
+ System.exit(0)
+ }
+
+ /** Get the Yarn approved local directories. */
+ private def getLocalDirs(): String = {
+ // Hadoop 0.23 and 2.x have different Environment variable names for the
+ // local dirs, so lets check both. We assume one of the 2 is set.
+ // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
+ val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
+ .getOrElse(Option(System.getenv("LOCAL_DIRS"))
+ .getOrElse(""))
+
+ if (localDirs.isEmpty()) {
+ throw new Exception("Yarn Local dirs can't be empty")
+ }
+ localDirs
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ appAttemptId
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for Spark driver to be reachable.")
+ var driverUp = false
+ var tries = 0
+ val numTries = System.getProperty("spark.yarn.applicationMaster.waitTries", "10").toInt
+ while (!driverUp && tries < numTries) {
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ try {
+ val socket = new Socket(driverHost, driverPort.toInt)
+ socket.close()
+ logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+ driverUp = true
+ } catch {
+ case e: Exception => {
+ logWarning("Failed to connect to driver at %s:%s, retrying ...".
+ format(driverHost, driverPort))
+ Thread.sleep(100)
+ tries = tries + 1
+ }
+ }
+ }
+ }
+
+ private def startUserClass(): Thread = {
+ logInfo("Starting the user JAR in a separate Thread")
+ val mainMethod = Class.forName(
+ args.userClass,
+ false /* initialize */,
+ Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
+ val t = new Thread {
+ override def run() {
+ var successed = false
+ try {
+ // Copy
+ var mainArgs: Array[String] = new Array[String](args.userArgs.size)
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
+ mainMethod.invoke(null, mainArgs)
+ // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR
+ // userThread will stop here unless it has uncaught exception thrown out
+ // It need shutdown hook to set SUCCEEDED
+ successed = true
+ } finally {
+ logDebug("finishing main")
+ isLastAMRetry = true
+ if (successed) {
+ ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ } else {
+ ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED)
+ }
+ }
+ }
+ }
+ t.start()
+ t
+ }
+
+ // This need to happen before allocateWorkers()
+ private def waitForSparkContextInitialized() {
+ logInfo("Waiting for Spark context initialization")
+ try {
+ var sparkContext: SparkContext = null
+ ApplicationMaster.sparkContextRef.synchronized {
+ var numTries = 0
+ val waitTime = 10000L
+ val maxNumTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt
+ while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) {
+ logInfo("Waiting for Spark context initialization ... " + numTries)
+ numTries = numTries + 1
+ ApplicationMaster.sparkContextRef.wait(waitTime)
+ }
+ sparkContext = ApplicationMaster.sparkContextRef.get()
+ assert(sparkContext != null || numTries >= maxNumTries)
+
+ if (sparkContext != null) {
+ uiAddress = sparkContext.ui.appUIAddress
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(
+ yarnConf,
+ amClient,
+ appAttemptId,
+ args,
+ sparkContext.preferredNodeLocationData)
+ } else {
+ logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d".
+ format(numTries * waitTime, maxNumTries))
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(
+ yarnConf,
+ amClient,
+ appAttemptId,
+ args)
+ }
+ }
+ } finally {
+ // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT :
+ // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ }
+
+ private def allocateWorkers() {
+ try {
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ yarnAllocator.addResourceRequests(args.numWorkers)
+ // Exits the loop if the user thread exits.
+ while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) {
+ if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of worker failures reached")
+ }
+ yarnAllocator.allocateResources()
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
+ }
+ } finally {
+ // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
+ // so that the loop in ApplicationMaster#sparkContextInitialized() breaks.
+ ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ }
+ logInfo("All workers have launched.")
+
+ // Launch a progress reporter thread, else the app will get killed after expiration
+ // (def: 10mins) timeout.
+ if (userThread.isAlive) {
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
+ launchReporterThread(interval)
+ }
+ }
+
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (userThread.isAlive) {
+ if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of worker failures reached")
+ }
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning -
+ yarnAllocator.getNumPendingAllocate
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating %d containers to make up for (potentially) lost containers".
+ format(missingWorkerCount))
+ yarnAllocator.addResourceRequests(missingWorkerCount)
+ }
+ sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // Setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // Simulated with an allocate request with no nodes requested.
+ yarnAllocator.allocateResources()
+ }
+
+ /*
+ def printContainers(containers: List[Container]) = {
+ for (container <- containers) {
+ logInfo("Launching shell command on a new container."
+ + ", containerId=" + container.getId()
+ + ", containerNode=" + container.getNodeId().getHost()
+ + ":" + container.getNodeId().getPort()
+ + ", containerNodeURI=" + container.getNodeHttpAddress()
+ + ", containerState" + container.getState()
+ + ", containerResourceMemory"
+ + container.getResource().getMemory())
+ }
+ }
+ */
+
+ def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") {
+ synchronized {
+ if (isFinished) {
+ return
+ }
+ isFinished = true
+ }
+
+ logInfo("finishApplicationMaster with " + status)
+ // Set tracking URL to empty since we don't have a history server.
+ amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */)
+ }
+
+ /**
+ * Clean up the staging directory.
+ */
+ private def cleanupStagingDir() {
+ var stagingDirPath: Path = null
+ try {
+ val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean
+ if (!preserveFiles) {
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
+ if (stagingDirPath == null) {
+ logError("Staging directory is null")
+ return
+ }
+ logInfo("Deleting staging directory " + stagingDirPath)
+ fs.delete(stagingDirPath, true)
+ }
+ } catch {
+ case ioe: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, ioe)
+ }
+ }
+
+ // The shutdown hook that runs when a signal is received AND during normal close of the JVM.
+ class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable {
+
+ def run() {
+ logInfo("AppMaster received a signal.")
+ // we need to clean up staging dir before HDFS is shut down
+ // make sure we don't delete it until this is the last AM
+ if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
+ }
+ }
+}
+
+object ApplicationMaster {
+ // Number of times to wait for the allocator loop to complete.
+ // Each loop iteration waits for 100ms, so maximum of 3 seconds.
+ // This is to ensure that we have reasonable number of containers before we start
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
+ // optimal as more containers are available. Might need to handle this better.
+ private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+
+ private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
+
+ val sparkContextRef: AtomicReference[SparkContext] =
+ new AtomicReference[SparkContext](null /* initialValue */)
+
+ val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+
+ def incrementAllocatorLoop(by: Int) {
+ val count = yarnAllocatorLoop.getAndAdd(by)
+ if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
+ yarnAllocatorLoop.synchronized {
+ // to wake threads off wait ...
+ yarnAllocatorLoop.notifyAll()
+ }
+ }
+ }
+
+ def register(master: ApplicationMaster) {
+ applicationMasters.add(master)
+ }
+
+ // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm...
+ def sparkContextInitialized(sc: SparkContext): Boolean = {
+ var modified = false
+ sparkContextRef.synchronized {
+ modified = sparkContextRef.compareAndSet(null, sc)
+ sparkContextRef.notifyAll()
+ }
+
+ // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do
+ // System.exit.
+ // Should not really have to do this, but it helps YARN to evict resources earlier.
+ // Not to mention, prevent the Client from declaring failure even though we exited properly.
+ // Note that this will unfortunately not properly clean up the staging files because it gets
+ // called too late, after the filesystem is already shutdown.
+ if (modified) {
+ Runtime.getRuntime().addShutdownHook(new Thread with Logging {
+ // This is not only logs, but also ensures that log system is initialized for this instance
+ // when we are actually 'run'-ing.
+ logInfo("Adding shutdown hook for context " + sc)
+ override def run() {
+ logInfo("Invoking sc stop from shutdown hook")
+ sc.stop()
+ // Best case ...
+ for (master <- applicationMasters) {
+ master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ }
+ }
+ } )
+ }
+
+ // Wait for initialization to complete and atleast 'some' nodes can get allocated.
+ yarnAllocatorLoop.synchronized {
+ while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
+ yarnAllocatorLoop.wait(1000L)
+ }
+ }
+ modified
+ }
+
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new ApplicationMaster(args).run()
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
new file mode 100644
index 0000000000..f76a5ddd39
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.util.IntParam
+import collection.mutable.ArrayBuffer
+
+class ApplicationMasterArguments(val args: Array[String]) {
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024
+ var workerCores = 1
+ var numWorkers = 2
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer = new ArrayBuffer[String]()
+
+ var args = inputArgs
+
+ while (! args.isEmpty) {
+
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: IntParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ }
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: org.apache.spark.deploy.yarn.ApplicationMaster [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
+ System.exit(exitCode)
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
new file mode 100644
index 0000000000..94678815e8
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -0,0 +1,519 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.{InetAddress, UnknownHostException, URI}
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil}
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.mapred.Master
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, Records}
+
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+import org.apache.spark.deploy.SparkHadoopUtil
+
+
+/**
+ * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The
+ * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster,
+ * which will launch a Spark master process and negotiate resources throughout its duration.
+ */
+class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ private val SPARK_STAGING: String = ".sparkStaging"
+ private val distCacheMgr = new ClientDistributedCacheManager()
+
+ // Staging directory is private! -> rwx--------
+ val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700: Short)
+ // App files are world-wide readable and owner writable -> rw-r--r--
+ val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644: Short)
+
+ def this(args: ClientArguments) = this(new Configuration(), args)
+
+ def runApp(): ApplicationId = {
+ validateArgs()
+ // Initialize and start the client service.
+ init(yarnConf)
+ start()
+
+ // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers).
+ logClusterResourceDetails()
+
+ // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM)
+ // interface).
+
+ // Get a new client application.
+ val newApp = super.createApplication()
+ val newAppResponse = newApp.getNewApplicationResponse()
+ val appId = newAppResponse.getApplicationId()
+
+ verifyClusterResources(newAppResponse)
+
+ // Set up resource and environment variables.
+ val appStagingDir = getAppStagingDir(appId)
+ val localResources = prepareLocalResources(appStagingDir)
+ val launchEnv = setupLaunchEnv(localResources, appStagingDir)
+ val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv)
+
+ // Set up an application submission context.
+ val appContext = newApp.getApplicationSubmissionContext()
+ appContext.setApplicationName(args.appName)
+ appContext.setQueue(args.amQueue)
+ appContext.setAMContainerSpec(amContainer)
+
+ // Memory for the ApplicationMaster.
+ val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
+ memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ appContext.setResource(memoryResource)
+
+ // Finally, submit and monitor the application.
+ submitApp(appContext)
+ appId
+ }
+
+ def run() {
+ val appId = runApp()
+ monitorApplication(appId)
+ System.exit(0)
+ }
+
+ // TODO(harvey): This could just go in ClientArguments.
+ def validateArgs() = {
+ Map(
+ (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!",
+ (args.userJar == null) -> "Error: You must specify a user jar!",
+ (args.userClass == null) -> "Error: You must specify a user class!",
+ (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!",
+ (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" +
+ "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD),
+ (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" +
+ "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString)
+ ).foreach { case(cond, errStr) =>
+ if (cond) {
+ logError(errStr)
+ args.printUsageAndExit(1)
+ }
+ }
+ }
+
+ def getAppStagingDir(appId: ApplicationId): String = {
+ SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ }
+
+ def logClusterResourceDetails() {
+ val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
+ clusterMetrics.getNumNodeManagers)
+
+ val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s,
+ queueApplicationCount = %s, queueChildQueueCount = %s""".format(
+ queueInfo.getQueueName,
+ queueInfo.getCurrentCapacity,
+ queueInfo.getMaximumCapacity,
+ queueInfo.getApplications.size,
+ queueInfo.getChildQueues.size))
+ }
+
+ def verifyClusterResources(app: GetNewApplicationResponse) = {
+ val maxMem = app.getMaximumResourceCapability().getMemory()
+ logInfo("Max mem capabililty of a single resource in this cluster " + maxMem)
+
+ // If we have requested more then the clusters max for a single resource then exit.
+ if (args.workerMemory > maxMem) {
+ logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.".
+ format(args.workerMemory, maxMem))
+ System.exit(1)
+ }
+ val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ if (amMem > maxMem) {
+ logError("Required AM memory (%d) is above the max threshold (%d) of this cluster".
+ format(args.amMemory, maxMem))
+ System.exit(1)
+ }
+
+ // We could add checks to make sure the entire cluster has enough resources but that involves
+ // getting all the node reports and computing ourselves.
+ }
+
+ /** See if two file systems are the same or not. */
+ private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+ if (srcUri.getScheme() == null) {
+ return false
+ }
+ if (!srcUri.getScheme().equals(dstUri.getScheme())) {
+ return false
+ }
+ var srcHost = srcUri.getHost()
+ var dstHost = dstUri.getHost()
+ if ((srcHost != null) && (dstHost != null)) {
+ try {
+ srcHost = InetAddress.getByName(srcHost).getCanonicalHostName()
+ dstHost = InetAddress.getByName(dstHost).getCanonicalHostName()
+ } catch {
+ case e: UnknownHostException =>
+ return false
+ }
+ if (!srcHost.equals(dstHost)) {
+ return false
+ }
+ } else if (srcHost == null && dstHost != null) {
+ return false
+ } else if (srcHost != null && dstHost == null) {
+ return false
+ }
+ //check for ports
+ if (srcUri.getPort() != dstUri.getPort()) {
+ return false
+ }
+ return true
+ }
+
+ /** Copy the file into HDFS if needed. */
+ private def copyRemoteFile(
+ dstDir: Path,
+ originalPath: Path,
+ replication: Short,
+ setPerms: Boolean = false): Path = {
+ val fs = FileSystem.get(conf)
+ val remoteFs = originalPath.getFileSystem(conf)
+ var newPath = originalPath
+ if (! compareFs(remoteFs, fs)) {
+ newPath = new Path(dstDir, originalPath.getName())
+ logInfo("Uploading " + originalPath + " to " + newPath)
+ FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf)
+ fs.setReplication(newPath, replication)
+ if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION))
+ }
+ // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific
+ // version shows the specific version in the distributed cache configuration
+ val qualPath = fs.makeQualified(newPath)
+ val fc = FileContext.getFileContext(qualPath.toUri(), conf)
+ val destPath = fc.resolvePath(qualPath)
+ destPath
+ }
+
+ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ // Upload Spark and the application JAR to the remote file system if necessary. Add them as
+ // local resources to the application master.
+ val fs = FileSystem.get(conf)
+
+ val delegTokenRenewer = Master.getMasterPrincipal(conf)
+ if (UserGroupInformation.isSecurityEnabled()) {
+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+ logError("Can't get Master Kerberos principal for use as renewer")
+ System.exit(1)
+ }
+ }
+ val dst = new Path(fs.getHomeDirectory(), appStagingDir)
+ val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort
+
+ if (UserGroupInformation.isSecurityEnabled()) {
+ val dstFs = dst.getFileSystem(conf)
+ dstFs.addDelegationTokens(delegTokenRenewer, credentials)
+ }
+
+ val localResources = HashMap[String, LocalResource]()
+ FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
+
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+ Map(
+ Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar,
+ Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")
+ ).foreach { case(destName, _localPath) =>
+ val localPath: String = if (_localPath != null) _localPath.trim() else ""
+ if (! localPath.isEmpty()) {
+ var localURI = new URI(localPath)
+ // If not specified assume these are in the local filesystem to keep behavior like Hadoop
+ if (localURI.getScheme() == null) {
+ localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString)
+ }
+ val setPermissions = if (destName.equals(Client.APP_JAR)) true else false
+ val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ destName, statCache)
+ }
+ }
+
+ // Handle jars local to the ApplicationMaster.
+ if ((args.addJars != null) && (!args.addJars.isEmpty())){
+ args.addJars.split(',').foreach { case file: String =>
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ // Only add the resource to the Spark ApplicationMaster.
+ val appMasterOnly = true
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache, appMasterOnly)
+ }
+ }
+
+ // Handle any distributed cache files
+ if ((args.files != null) && (!args.files.isEmpty())){
+ args.files.split(',').foreach { case file: String =>
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache)
+ }
+ }
+
+ // Handle any distributed cache archives
+ if ((args.archives != null) && (!args.archives.isEmpty())) {
+ args.archives.split(',').foreach { case file:String =>
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE,
+ linkname, statCache)
+ }
+ }
+
+ UserGroupInformation.getCurrentUser().addCredentials(credentials)
+ localResources
+ }
+
+ def setupLaunchEnv(
+ localResources: HashMap[String, LocalResource],
+ stagingDir: String): HashMap[String, String] = {
+ logInfo("Setting up the launch environment")
+ val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null)
+
+ val env = new HashMap[String, String]()
+
+ Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
+ env("SPARK_YARN_MODE") = "true"
+ env("SPARK_YARN_STAGING_DIR") = stagingDir
+
+ // Set the environment variables to be passed on to the Workers.
+ distCacheMgr.setDistFilesEnv(env)
+ distCacheMgr.setDistArchivesEnv(env)
+
+ // Allow users to specify some environment variables.
+ Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
+
+ // Add each SPARK_* key to the environment.
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+
+ env
+ }
+
+ def userArgsToString(clientArgs: ClientArguments): String = {
+ val prefix = " --args "
+ val args = clientArgs.userArgs
+ val retval = new StringBuilder()
+ for (arg <- args){
+ retval.append(prefix).append(" '").append(arg).append("' ")
+ }
+ retval.toString
+ }
+
+ def createContainerLaunchContext(
+ newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
+ logInfo("Setting up container launch context")
+ val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
+ amContainer.setLocalResources(localResources)
+ amContainer.setEnvironment(env)
+
+ // TODO: Need a replacement for the following code to fix -Xmx?
+ // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+ // var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
+ // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) -
+ // YarnAllocationHandler.MEMORY_OVERHEAD)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+
+ // Add Xmx for AM memory
+ JAVA_OPTS += "-Xmx" + args.amMemory + "m"
+
+ val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
+ JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir
+
+ // TODO: Remove once cpuset version is pushed out.
+ // The context is, default gc for server class machines ends up using all cores to do gc -
+ // hence if there are multiple containers in same node, Spark GC affects all other containers'
+ // performance (which can be that of other Spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in
+ // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset
+ // of cores on a node.
+ val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") &&
+ java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))
+ if (useConcurrentAndIncrementalGC) {
+ // In our expts, using (default) throughput collector has severe perf ramifications in
+ // multi-tenant machines
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += " " + env("SPARK_JAVA_OPTS")
+ }
+
+ // Command for the ApplicationMaster
+ var javaCommand = "java"
+ val javaHome = System.getenv("JAVA_HOME")
+ if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
+ javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+ }
+
+ val commands = List[String](
+ javaCommand +
+ " -server " +
+ JAVA_OPTS +
+ " " + args.amClass +
+ " --class " + args.userClass +
+ " --jar " + args.userJar +
+ userArgsToString(args) +
+ " --worker-memory " + args.workerMemory +
+ " --worker-cores " + args.workerCores +
+ " --num-workers " + args.numWorkers +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+
+ logInfo("Command for starting the Spark ApplicationMaster: " + commands(0))
+ amContainer.setCommands(commands)
+
+ // Setup security tokens.
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ amContainer.setTokens(ByteBuffer.wrap(dob.getData()))
+
+ amContainer
+ }
+
+ def submitApp(appContext: ApplicationSubmissionContext) = {
+ // Submit the application to the applications manager.
+ logInfo("Submitting application to ASM")
+ super.submitApplication(appContext)
+ }
+
+ def monitorApplication(appId: ApplicationId): Boolean = {
+ while (true) {
+ Thread.sleep(1000)
+ val report = super.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t application identifier: " + appId.toString() + "\n" +
+ "\t appId: " + appId.getId() + "\n" +
+ "\t clientToAMToken: " + report.getClientToAMToken() + "\n" +
+ "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
+ "\t appMasterHost: " + report.getHost() + "\n" +
+ "\t appQueue: " + report.getQueue() + "\n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
+ "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
+ "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
+ "\t appUser: " + report.getUser()
+ )
+
+ val state = report.getYarnApplicationState()
+ val dsStatus = report.getFinalApplicationStatus()
+ if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ return true
+ }
+ }
+ true
+ }
+}
+
+object Client {
+ val SPARK_JAR: String = "spark.jar"
+ val APP_JAR: String = "app.jar"
+ val LOG4J_PROP: String = "log4j.properties"
+
+ def main(argStrings: Array[String]) {
+ // Set an env variable indicating we are running in YARN mode.
+ // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes -
+ // see Client#setupLaunchEnv().
+ System.setProperty("SPARK_YARN_MODE", "true")
+
+ val args = new ClientArguments(argStrings)
+
+ (new Client(args)).run()
+ }
+
+ // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
+ for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
+ }
+ }
+
+ def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ // If log4j present, ensure ours overrides all others
+ if (addLog4j) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + LOG4J_PROP)
+ }
+ // Normally the users app.jar is last in case conflicts with spark jars
+ val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
+ .toBoolean
+ if (userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + APP_JAR)
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + SPARK_JAR)
+ Client.populateHadoopClasspath(conf, env)
+
+ if (!userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + APP_JAR)
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "*")
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
new file mode 100644
index 0000000000..70be15d0a3
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo}
+import org.apache.spark.util.IntParam
+import org.apache.spark.util.MemoryParam
+
+
+// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware !
+class ClientArguments(val args: Array[String]) {
+ var addJars: String = null
+ var files: String = null
+ var archives: String = null
+ var userJar: String = null
+ var userClass: String = null
+ var userArgs: Seq[String] = Seq[String]()
+ var workerMemory = 1024 // MB
+ var workerCores = 1
+ var numWorkers = 2
+ var amQueue = System.getProperty("QUEUE", "default")
+ var amMemory: Int = 512 // MB
+ var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
+ var appName: String = "Spark"
+ // TODO
+ var inputFormatInfo: List[InputFormatInfo] = null
+ // TODO(harvey)
+ var priority = 0
+
+ parseArgs(args.toList)
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
+ val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
+
+ var args = inputArgs
+
+ while (!args.isEmpty) {
+ args match {
+ case ("--jar") :: value :: tail =>
+ userJar = value
+ args = tail
+
+ case ("--class") :: value :: tail =>
+ userClass = value
+ args = tail
+
+ case ("--args") :: value :: tail =>
+ userArgsBuffer += value
+ args = tail
+
+ case ("--master-class") :: value :: tail =>
+ amClass = value
+ args = tail
+
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
+ args = tail
+
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
+ case ("--worker-memory") :: MemoryParam(value) :: tail =>
+ workerMemory = value
+ args = tail
+
+ case ("--worker-cores") :: IntParam(value) :: tail =>
+ workerCores = value
+ args = tail
+
+ case ("--queue") :: value :: tail =>
+ amQueue = value
+ args = tail
+
+ case ("--name") :: value :: tail =>
+ appName = value
+ args = tail
+
+ case ("--addJars") :: value :: tail =>
+ addJars = value
+ args = tail
+
+ case ("--files") :: value :: tail =>
+ files = value
+ args = tail
+
+ case ("--archives") :: value :: tail =>
+ archives = value
+ args = tail
+
+ case Nil =>
+ if (userJar == null || userClass == null) {
+ printUsageAndExit(1)
+ }
+
+ case _ =>
+ printUsageAndExit(1, args)
+ }
+ }
+
+ userArgs = userArgsBuffer.readOnly
+ inputFormatInfo = inputFormatMap.values.toList
+ }
+
+
+ def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ if (unknownParam != null) {
+ System.err.println("Unknown/unsupported param " + unknownParam)
+ }
+ System.err.println(
+ "Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
+ "Options:\n" +
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
+ )
+ System.exit(exitCode)
+ }
+
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..5f159b073f
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import org.apache.spark.Logging
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.Map
+
+
+/** Client side methods to setup the Hadoop distributed cache */
+class ClientDistributedCacheManager() extends Logging {
+ private val distCacheFiles: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+ private val distCacheArchives: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the workers so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the workers and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false) = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
+ localResources(link) = amJarRsrc
+
+ if (appMasterOnly == false) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ if (resourceType == LocalResourceType.FILE) {
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ } else {
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ }
+ }
+ }
+
+ /**
+ * Adds the necessary cache file env variables to the env passed in
+ * @param env
+ */
+ def setDistFilesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheFiles.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Adds the necessary cache archive env variables to the env passed in
+ * @param env
+ */
+ def setDistArchivesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheArchives.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return LocalResourceVisibility
+ */
+ def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ return LocalResourceVisibility.PUBLIC
+ }
+ return LocalResourceVisibility.PRIVATE
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all(public)
+ * or not
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ //the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory heirarchy to the given path)
+ * @param fs
+ * @param path
+ * @param statCache
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ //the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ return true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @param fs
+ * @param path
+ * @param action
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def checkPermissionOfOther(fs: FileSystem, path: Path,
+ action: FsAction, statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache)
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ if (otherAction.implies(action)) {
+ return true
+ }
+ return false
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @param fs
+ * @param uri
+ * @param statCache
+ * @return FileStatus
+ */
+ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ return stat
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000000..bc31bb2eb0
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote._
+import akka.actor.Terminated
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private var appAttemptId: ApplicationAttemptId = _
+ private var reporterThread: Thread = _
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = _
+ private var driverClosed:Boolean = false
+
+ private var amClient: AMRMClient[ContainerRequest] = _
+
+ val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1
+ var actor: ActorRef = _
+
+ // This actor just working as a monitor to watch on Driver Actor.
+ class MonitorActor(driverUrl: String) extends Actor {
+
+ var driver: ActorSelection = null
+
+ override def preStart() {
+ logInfo("Listen to driver: " + driverUrl)
+ driver = context.actorSelection(driverUrl)
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case x: DisassociatedEvent =>
+ logInfo("Driver terminated or disconnected! Shutting down.")
+ driverClosed = true
+ }
+ }
+
+ def run() {
+
+ amClient = AMRMClient.createAMRMClient()
+ amClient.init(yarnConf)
+ amClient.start()
+
+ appAttemptId = getApplicationAttemptId()
+ registerApplicationMaster()
+
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ reporterThread = launchReporterThread(interval)
+
+ // Wait for the reporter thread to Finish.
+ reporterThread.join()
+
+ finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ actorSystem.shutdown()
+
+ logInfo("Exited")
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ appAttemptId
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ // TODO:(Raymond) Find out Spark UI address and fill in here?
+ amClient.registerApplicationMaster(Utils.localHostName(), 0, "")
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for Spark driver to be reachable.")
+ var driverUp = false
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+ while(!driverUp) {
+ try {
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at %s:%s, retrying ...".
+ format(driverHost, driverPort))
+ Thread.sleep(100)
+ }
+ }
+ System.setProperty("spark.driver.host", driverHost)
+ System.setProperty("spark.driver.port", driverPort.toString)
+
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+ }
+
+
+ private def allocateWorkers() {
+
+ // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
+ scala.collection.immutable.Map()
+
+ yarnAllocator = YarnAllocationHandler.newAllocator(
+ yarnConf,
+ amClient,
+ appAttemptId,
+ args,
+ preferredNodeLocationData)
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+
+ yarnAllocator.addResourceRequests(args.numWorkers)
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers) {
+ yarnAllocator.allocateResources()
+ Thread.sleep(100)
+ }
+
+ logInfo("All workers have launched.")
+
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (!driverClosed) {
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning -
+ yarnAllocator.getNumPendingAllocate
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating %d containers to make up for (potentially) lost containers".
+ format(missingWorkerCount))
+ yarnAllocator.addResourceRequests(missingWorkerCount)
+ }
+ sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateResources()
+ }
+
+ def finishApplicationMaster(status: FinalApplicationStatus) {
+ logInfo("finish ApplicationMaster with " + status)
+ amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */)
+ }
+
+}
+
+
+object WorkerLauncher {
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new WorkerLauncher(args).run()
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
new file mode 100644
index 0000000000..9f5523c4b9
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+import java.nio.ByteBuffer
+import java.security.PrivilegedExceptionAction
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.client.api.NMClient
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+
+import org.apache.spark.Logging
+
+
+class WorkerRunnable(
+ container: Container,
+ conf: Configuration,
+ masterAddress: String,
+ slaveId: String,
+ hostname: String,
+ workerMemory: Int,
+ workerCores: Int)
+ extends Runnable with Logging {
+
+ var rpc: YarnRPC = YarnRPC.create(conf)
+ var nmClient: NMClient = _
+ val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ def run = {
+ logInfo("Starting Worker Container")
+ nmClient = NMClient.createNMClient()
+ nmClient.init(yarnConf)
+ nmClient.start()
+ startContainer
+ }
+
+ def startContainer = {
+ logInfo("Setting up ContainerLaunchContext")
+
+ val ctx = Records.newRecord(classOf[ContainerLaunchContext])
+ .asInstanceOf[ContainerLaunchContext]
+
+ val localResources = prepareLocalResources
+ ctx.setLocalResources(localResources)
+
+ val env = prepareEnvironment
+ ctx.setEnvironment(env)
+
+ // Extra options for the JVM
+ var JAVA_OPTS = ""
+ // Set the JVM memory
+ val workerMemoryString = workerMemory + "m"
+ JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
+ if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
+ JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
+ }
+
+ JAVA_OPTS += " -Djava.io.tmpdir=" +
+ new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
+
+ // Commenting it out for now - so that people can refer to the properties if required. Remove
+ // it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence
+ // if there are multiple containers in same node, spark gc effects all other containers
+ // performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+ // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+ // of cores on a node.
+/*
+ else {
+ // If no java_opts specified, default to using -XX:+CMSIncrementalMode
+ // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont
+ // want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in
+ // multi-tennent machines
+ // The options are based on
+ // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
+ JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
+ JAVA_OPTS += " -XX:+CMSIncrementalMode "
+ JAVA_OPTS += " -XX:+CMSIncrementalPacing "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
+ JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ }
+*/
+
+ val credentials = UserGroupInformation.getCurrentUser().getCredentials()
+ val dob = new DataOutputBuffer()
+ credentials.writeTokenStorageToStream(dob)
+ ctx.setTokens(ByteBuffer.wrap(dob.getData()))
+
+ var javaCommand = "java"
+ val javaHome = System.getenv("JAVA_HOME")
+ if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
+ javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
+ }
+
+ val commands = List[String](javaCommand +
+ " -server " +
+ // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
+ // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in
+ // an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do
+ // 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ " -XX:OnOutOfMemoryError='kill %p' " +
+ JAVA_OPTS +
+ " org.apache.spark.executor.CoarseGrainedExecutorBackend " +
+ masterAddress + " " +
+ slaveId + " " +
+ hostname + " " +
+ workerCores +
+ " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
+ " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
+ logInfo("Setting up worker with commands: " + commands)
+ ctx.setCommands(commands)
+
+ // Send the start request to the ContainerManager
+ nmClient.startContainer(container, ctx)
+ }
+
+ private def setupDistributedCache(
+ file: String,
+ rtype: LocalResourceType,
+ localResources: HashMap[String, LocalResource],
+ timestamp: String,
+ size: String,
+ vis: String) = {
+ val uri = new URI(file)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(rtype)
+ amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis))
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
+ amJarRsrc.setTimestamp(timestamp.toLong)
+ amJarRsrc.setSize(size.toLong)
+ localResources(uri.getFragment()) = amJarRsrc
+ }
+
+ def prepareLocalResources: HashMap[String, LocalResource] = {
+ logInfo("Preparing Local resources")
+ val localResources = HashMap[String, LocalResource]()
+
+ if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',')
+ for( i <- 0 to distFiles.length - 1) {
+ setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
+ fileSizes(i), visibilities(i))
+ }
+ }
+
+ if (System.getenv("SPARK_YARN_CACHE_ARCHIVES") != null) {
+ val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
+ val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
+ val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',')
+ for( i <- 0 to distArchives.length - 1) {
+ setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
+ timeStamps(i), fileSizes(i), visibilities(i))
+ }
+ }
+
+ logInfo("Prepared Local resources " + localResources)
+ localResources
+ }
+
+ def prepareEnvironment: HashMap[String, String] = {
+ val env = new HashMap[String, String]()
+
+ Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
+
+ // Allow users to specify some environment variables
+ Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
+
+ System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
+ env
+ }
+
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
new file mode 100644
index 0000000000..c27257cda4
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.lang.{Boolean => JBoolean}
+import java.util.{Collections, Set => JSet}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import org.apache.spark.Logging
+import org.apache.spark.scheduler.SplitInfo
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.util.Utils
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.ApplicationMasterProtocol
+import org.apache.hadoop.yarn.api.records.ApplicationAttemptId
+import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus}
+import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest}
+import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
+import org.apache.hadoop.yarn.client.api.AMRMClient
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
+import org.apache.hadoop.yarn.util.{RackResolver, Records}
+
+
+object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
+ type AllocationType = Value
+ val HOST, RACK, ANY = Value
+}
+
+// TODO:
+// Too many params.
+// Needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should
+// make it more proactive and decoupled.
+
+// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for
+// more info on how we are requesting for containers.
+private[yarn] class YarnAllocationHandler(
+ val conf: Configuration,
+ val amClient: AMRMClient[ContainerRequest],
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int,
+ val workerMemory: Int,
+ val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
+ extends Logging {
+ // These three are locked on allocatedHostToContainersMap. Complementary data structures
+ // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
+ // allocatedContainerToHostMap: container to host mapping.
+ private val allocatedHostToContainersMap =
+ new HashMap[String, collection.mutable.Set[ContainerId]]()
+
+ private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
+
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an
+ // allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on
+ // allocatedHostToContainersMap
+ private val allocatedRackCount = new HashMap[String, Int]()
+
+ // Containers which have been released.
+ private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
+ // Containers to be released in next request to RM
+ private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
+
+ // Number of container requests that have been sent to, but not yet allocated by the
+ // ApplicationMaster.
+ private val numPendingAllocate = new AtomicInteger()
+ private val numWorkersRunning = new AtomicInteger()
+ // Used to generate a unique id per worker
+ private val workerIdCounter = new AtomicInteger()
+ private val lastResponseId = new AtomicInteger()
+ private val numWorkersFailed = new AtomicInteger()
+
+ def getNumPendingAllocate: Int = numPendingAllocate.intValue
+
+ def getNumWorkersRunning: Int = numWorkersRunning.intValue
+
+ def getNumWorkersFailed: Int = numWorkersFailed.intValue
+
+ def isResourceConstraintSatisfied(container: Container): Boolean = {
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ }
+
+ def releaseContainer(container: Container) {
+ val containerId = container.getId
+ pendingReleaseContainers.put(containerId, true)
+ amClient.releaseAssignedContainer(containerId)
+ }
+
+ def allocateResources() {
+ // We have already set the container request. Poll the ResourceManager for a response.
+ // This doubles as a heartbeat if there are no pending container requests.
+ val progressIndicator = 0.1f
+ val allocateResponse = amClient.allocate(progressIndicator)
+
+ val allocatedContainers = allocateResponse.getAllocatedContainers()
+ if (allocatedContainers.size > 0) {
+ var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size)
+
+ if (numPendingAllocateNow < 0) {
+ numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow)
+ }
+
+ logDebug("""
+ Allocated containers: %d
+ Current worker count: %d
+ Containers released: %s
+ Containers to-be-released: %s
+ Cluster resources: %s
+ """.format(
+ allocatedContainers.size,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers,
+ allocateResponse.getAvailableResources))
+
+ val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (container <- allocatedContainers) {
+ if (isResourceConstraintSatisfied(container)) {
+ // Add the accepted `container` to the host's list of already accepted,
+ // allocated containers
+ val host = container.getNodeId.getHost
+ val containersForHost = hostToContainers.getOrElseUpdate(host,
+ new ArrayBuffer[Container]())
+ containersForHost += container
+ } else {
+ // Release container, since it doesn't satisfy resource constraints.
+ releaseContainer(container)
+ }
+ }
+
+ // Find the appropriate containers to use.
+ // TODO: Cleanup this group-by...
+ val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
+ val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
+
+ for (candidateHost <- hostToContainers.keySet) {
+ val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
+ val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
+
+ val remainingContainersOpt = hostToContainers.get(candidateHost)
+ assert(remainingContainersOpt.isDefined)
+ var remainingContainers = remainingContainersOpt.get
+
+ if (requiredHostCount >= remainingContainers.size) {
+ // Since we have <= required containers, add all remaining containers to
+ // `dataLocalContainers`.
+ dataLocalContainers.put(candidateHost, remainingContainers)
+ // There are no more free containers remaining.
+ remainingContainers = null
+ } else if (requiredHostCount > 0) {
+ // Container list has more containers than we need for data locality.
+ // Split the list into two: one based on the data local container count,
+ // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+ // containers.
+ val (dataLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredHostCount)
+ dataLocalContainers.put(candidateHost, dataLocal)
+
+ // Invariant: remainingContainers == remaining
+
+ // YARN has a nasty habit of allocating a ton of containers on a host - discourage this.
+ // Add each container in `remaining` to list of containers to release. If we have an
+ // insufficient number of containers, then the next allocation cycle will reallocate
+ // (but won't treat it as data local).
+ // TODO(harvey): Rephrase this comment some more.
+ for (container <- remaining) releaseContainer(container)
+ remainingContainers = null
+ }
+
+ // For rack local containers
+ if (remainingContainers != null) {
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
+ val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
+ rackLocalContainers.getOrElse(rack, List()).size
+
+ if (requiredRackCount >= remainingContainers.size) {
+ // Add all remaining containers to to `dataLocalContainers`.
+ dataLocalContainers.put(rack, remainingContainers)
+ remainingContainers = null
+ } else if (requiredRackCount > 0) {
+ // Container list has more containers that we need for data locality.
+ // Split the list into two: one based on the data local container count,
+ // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining
+ // containers.
+ val (rackLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
+ new ArrayBuffer[Container]())
+
+ existingRackLocal ++= rackLocal
+
+ remainingContainers = remaining
+ }
+ }
+ }
+
+ if (remainingContainers != null) {
+ // Not all containers have been consumed - add them to the list of off-rack containers.
+ offRackContainers.put(candidateHost, remainingContainers)
+ }
+ }
+
+ // Now that we have split the containers into various groups, go through them in order:
+ // first host-local, then rack-local, and finally off-rack.
+ // Note that the list we create below tries to ensure that not all containers end up within
+ // a host if there is a sufficiently large number of hosts/containers.
+ val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size)
+ allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
+ allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
+ allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers)
+
+ // Run each of the allocated containers.
+ for (container <- allocatedContainersToProcess) {
+ val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
+ val workerHostname = container.getNodeId.getHost
+ val containerId = container.getId
+
+ val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+ assert(container.getResource.getMemory >= workerMemoryOverhead)
+
+ if (numWorkersRunningNow > maxWorkers) {
+ logInfo("""Ignoring container %s at host %s, since we already have the required number of
+ containers for it.""".format(containerId, workerHostname))
+ releaseContainer(container)
+ numWorkersRunning.decrementAndGet()
+ } else {
+ val workerId = workerIdCounter.incrementAndGet().toString
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"),
+ System.getProperty("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ logInfo("Launching container %s for on host %s".format(containerId, workerHostname))
+
+ // To be safe, remove the container from `pendingReleaseContainers`.
+ pendingReleaseContainers.remove(containerId)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
+ allocatedHostToContainersMap.synchronized {
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname,
+ new HashSet[ContainerId]())
+
+ containerSet += containerId
+ allocatedContainerToHostMap.put(containerId, workerHostname)
+
+ if (rack != null) {
+ allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
+ }
+ logInfo("Launching WorkerRunnable. driverUrl: %s, workerHostname: %s".format(driverUrl, workerHostname))
+ val workerRunnable = new WorkerRunnable(
+ container,
+ conf,
+ driverUrl,
+ workerId,
+ workerHostname,
+ workerMemory,
+ workerCores)
+ new Thread(workerRunnable).start()
+ }
+ }
+ logDebug("""
+ Finished allocating %s containers (from %s originally).
+ Current number of workers running: %d,
+ releasedContainerList: %s,
+ pendingReleaseContainers: %s
+ """.format(
+ allocatedContainersToProcess,
+ allocatedContainers,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers))
+ }
+
+ val completedContainers = allocateResponse.getCompletedContainersStatuses()
+ if (completedContainers.size > 0) {
+ logDebug("Completed %d containers".format(completedContainers.size))
+
+ for (completedContainer <- completedContainers) {
+ val containerId = completedContainer.getContainerId
+
+ if (pendingReleaseContainers.containsKey(containerId)) {
+ // YarnAllocationHandler already marked the container for release, so remove it from
+ // `pendingReleaseContainers`.
+ pendingReleaseContainers.remove(containerId)
+ } else {
+ // Decrement the number of workers running. The next iteration of the ApplicationMaster's
+ // reporting thread will take care of allocating.
+ numWorkersRunning.decrementAndGet()
+ logInfo("Completed container %s (state: %s, exit status: %s)".format(
+ containerId,
+ completedContainer.getState,
+ completedContainer.getExitStatus()))
+ // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
+ // there are some exit status' we shouldn't necessarily count against us, but for
+ // now I think its ok as none of the containers are expected to exit
+ if (completedContainer.getExitStatus() != 0) {
+ logInfo("Container marked as failed: " + containerId)
+ numWorkersFailed.incrementAndGet()
+ }
+ }
+
+ allocatedHostToContainersMap.synchronized {
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val hostOpt = allocatedContainerToHostMap.get(containerId)
+ assert(hostOpt.isDefined)
+ val host = hostOpt.get
+
+ val containerSetOpt = allocatedHostToContainersMap.get(host)
+ assert(containerSetOpt.isDefined)
+ val containerSet = containerSetOpt.get
+
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
+ }
+
+ allocatedContainerToHostMap.remove(containerId)
+
+ // TODO: Move this part outside the synchronized block?
+ val rack = YarnAllocationHandler.lookupRack(conf, host)
+ if (rack != null) {
+ val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
+ if (rackCount > 0) {
+ allocatedRackCount.put(rack, rackCount)
+ } else {
+ allocatedRackCount.remove(rack)
+ }
+ }
+ }
+ }
+ }
+ logDebug("""
+ Finished processing %d completed containers.
+ Current number of workers running: %d,
+ releasedContainerList: %s,
+ pendingReleaseContainers: %s
+ """.format(
+ completedContainers.size,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers))
+ }
+ }
+
+ def createRackResourceRequests(
+ hostContainers: ArrayBuffer[ContainerRequest]
+ ): ArrayBuffer[ContainerRequest] = {
+ // Generate modified racks and new set of hosts under it before issuing requests.
+ val rackToCounts = new HashMap[String, Int]()
+
+ for (container <- hostContainers) {
+ val candidateHost = container.getNodes.last
+ assert(YarnAllocationHandler.ANY_HOST != candidateHost)
+
+ val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
+ if (rack != null) {
+ var count = rackToCounts.getOrElse(rack, 0)
+ count += 1
+ rackToCounts.put(rack, count)
+ }
+ }
+
+ val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size)
+ for ((rack, count) <- rackToCounts) {
+ requestedContainers ++= createResourceRequests(
+ AllocationType.RACK,
+ rack,
+ count,
+ YarnAllocationHandler.PRIORITY)
+ }
+
+ requestedContainers
+ }
+
+ def allocatedContainersOnHost(host: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
+ }
+ retval
+ }
+
+ def allocatedContainersOnRack(rack: String): Int = {
+ var retval = 0
+ allocatedHostToContainersMap.synchronized {
+ retval = allocatedRackCount.getOrElse(rack, 0)
+ }
+ retval
+ }
+
+ def addResourceRequests(numWorkers: Int) {
+ val containerRequests: List[ContainerRequest] =
+ if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
+ logDebug("numWorkers: " + numWorkers + ", host preferences: " +
+ preferredHostToCount.isEmpty)
+ createResourceRequests(
+ AllocationType.ANY,
+ resource = null,
+ numWorkers,
+ YarnAllocationHandler.PRIORITY).toList
+ } else {
+ // Request for all hosts in preferred nodes and for numWorkers -
+ // candidates.size, request by default allocation policy.
+ val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size)
+ for ((candidateHost, candidateCount) <- preferredHostToCount) {
+ val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
+
+ if (requiredCount > 0) {
+ hostContainerRequests ++= createResourceRequests(
+ AllocationType.HOST,
+ candidateHost,
+ requiredCount,
+ YarnAllocationHandler.PRIORITY)
+ }
+ }
+ val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests(
+ hostContainerRequests).toList
+
+ val anyContainerRequests = createResourceRequests(
+ AllocationType.ANY,
+ resource = null,
+ numWorkers,
+ YarnAllocationHandler.PRIORITY)
+
+ val containerRequestBuffer = new ArrayBuffer[ContainerRequest](
+ hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size)
+
+ containerRequestBuffer ++= hostContainerRequests
+ containerRequestBuffer ++= rackContainerRequests
+ containerRequestBuffer ++= anyContainerRequests
+ containerRequestBuffer.toList
+ }
+
+ for (request <- containerRequests) {
+ amClient.addContainerRequest(request)
+ }
+
+ if (numWorkers > 0) {
+ numPendingAllocate.addAndGet(numWorkers)
+ logInfo("Will Allocate %d worker containers, each with %d memory".format(
+ numWorkers,
+ (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)))
+ } else {
+ logDebug("Empty allocation request ...")
+ }
+
+ for (request <- containerRequests) {
+ val nodes = request.getNodes
+ var hostStr = if (nodes == null || nodes.isEmpty) {
+ "Any"
+ } else {
+ nodes.last
+ }
+ logInfo("Container request (host: %s, priority: %s, capability: %s".format(
+ hostStr,
+ request.getPriority().getPriority,
+ request.getCapability))
+ }
+ }
+
+ private def createResourceRequests(
+ requestType: AllocationType.AllocationType,
+ resource: String,
+ numWorkers: Int,
+ priority: Int
+ ): ArrayBuffer[ContainerRequest] = {
+
+ // If hostname is specified, then we need at least two requests - node local and rack local.
+ // There must be a third request, which is ANY. That will be specially handled.
+ requestType match {
+ case AllocationType.HOST => {
+ assert(YarnAllocationHandler.ANY_HOST != resource)
+ val hostname = resource
+ val nodeLocal = constructContainerRequests(
+ Array(hostname),
+ racks = null,
+ numWorkers,
+ priority)
+
+ // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler.
+ YarnAllocationHandler.populateRackInfo(conf, hostname)
+ nodeLocal
+ }
+ case AllocationType.RACK => {
+ val rack = resource
+ constructContainerRequests(hosts = null, Array(rack), numWorkers, priority)
+ }
+ case AllocationType.ANY => constructContainerRequests(
+ hosts = null, racks = null, numWorkers, priority)
+ case _ => throw new IllegalArgumentException(
+ "Unexpected/unsupported request type: " + requestType)
+ }
+ }
+
+ private def constructContainerRequests(
+ hosts: Array[String],
+ racks: Array[String],
+ numWorkers: Int,
+ priority: Int
+ ): ArrayBuffer[ContainerRequest] = {
+
+ val memoryResource = Records.newRecord(classOf[Resource])
+ memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
+
+ val prioritySetting = Records.newRecord(classOf[Priority])
+ prioritySetting.setPriority(priority)
+
+ val requests = new ArrayBuffer[ContainerRequest]()
+ for (i <- 0 until numWorkers) {
+ requests += new ContainerRequest(memoryResource, hosts, racks, prioritySetting)
+ }
+ requests
+ }
+}
+
+object YarnAllocationHandler {
+
+ val ANY_HOST = "*"
+ // All requests are issued with same priority : we do not (yet) have any distinction between
+ // request types (like map/reduce in hadoop for example)
+ val PRIORITY = 1
+
+ // Additional memory overhead - in mb.
+ val MEMORY_OVERHEAD = 384
+
+ // Host to rack map - saved from allocation requests. We are expecting this not to change.
+ // Note that it is possible for this to change : and ResurceManager will indicate that to us via
+ // update response to allocate. But we are punting on handling that for now.
+ private val hostToRack = new ConcurrentHashMap[String, String]()
+ private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
+
+
+ def newAllocator(
+ conf: Configuration,
+ amClient: AMRMClient[ContainerRequest],
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments
+ ): YarnAllocationHandler = {
+ new YarnAllocationHandler(
+ conf,
+ amClient,
+ appAttemptId,
+ args.numWorkers,
+ args.workerMemory,
+ args.workerCores,
+ Map[String, Int](),
+ Map[String, Int]())
+ }
+
+ def newAllocator(
+ conf: Configuration,
+ amClient: AMRMClient[ContainerRequest],
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String,
+ collection.Set[SplitInfo]]
+ ): YarnAllocationHandler = {
+ val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map)
+ new YarnAllocationHandler(
+ conf,
+ amClient,
+ appAttemptId,
+ args.numWorkers,
+ args.workerMemory,
+ args.workerCores,
+ hostToSplitCount,
+ rackToSplitCount)
+ }
+
+ def newAllocator(
+ conf: Configuration,
+ amClient: AMRMClient[ContainerRequest],
+ appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int,
+ workerMemory: Int,
+ workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]
+ ): YarnAllocationHandler = {
+ val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
+ new YarnAllocationHandler(
+ conf,
+ amClient,
+ appAttemptId,
+ maxWorkers,
+ workerMemory,
+ workerCores,
+ hostToCount,
+ rackToCount)
+ }
+
+ // A simple method to copy the split info map.
+ private def generateNodeToWeight(
+ conf: Configuration,
+ input: collection.Map[String, collection.Set[SplitInfo]]
+ ): (Map[String, Int], Map[String, Int]) = {
+
+ if (input == null) {
+ return (Map[String, Int](), Map[String, Int]())
+ }
+
+ val hostToCount = new HashMap[String, Int]
+ val rackToCount = new HashMap[String, Int]
+
+ for ((host, splits) <- input) {
+ val hostCount = hostToCount.getOrElse(host, 0)
+ hostToCount.put(host, hostCount + splits.size)
+
+ val rack = lookupRack(conf, host)
+ if (rack != null){
+ val rackCount = rackToCount.getOrElse(host, 0)
+ rackToCount.put(host, rackCount + splits.size)
+ }
+ }
+
+ (hostToCount.toMap, rackToCount.toMap)
+ }
+
+ def lookupRack(conf: Configuration, host: String): String = {
+ if (!hostToRack.contains(host)) {
+ populateRackInfo(conf, host)
+ }
+ hostToRack.get(host)
+ }
+
+ def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
+ Option(rackToHostSet.get(rack)).map { set =>
+ val convertedSet: collection.mutable.Set[String] = set
+ // TODO: Better way to get a Set[String] from JSet.
+ convertedSet.toSet
+ }
+ }
+
+ def populateRackInfo(conf: Configuration, hostname: String) {
+ Utils.checkHost(hostname)
+
+ if (!hostToRack.containsKey(hostname)) {
+ // If there are repeated failures to resolve, all to an ignore list.
+ val rackInfo = RackResolver.resolve(conf, hostname)
+ if (rackInfo != null && rackInfo.getNetworkLocation != null) {
+ val rack = rackInfo.getNetworkLocation
+ hostToRack.put(hostname, rack)
+ if (! rackToHostSet.containsKey(rack)) {
+ rackToHostSet.putIfAbsent(rack,
+ Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ }
+ rackToHostSet.get(rack).add(hostname)
+
+ // TODO(harvey): Figure out what this comment means...
+ // Since RackResolver caches, we are disabling this for now ...
+ } /* else {
+ // right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
+ hostToRack.put(hostname, null)
+ } */
+ }
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
new file mode 100644
index 0000000000..2ba2366ead
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.conf.Configuration
+
+/**
+ * Contains util methods to interact with Hadoop from spark.
+ */
+class YarnSparkHadoopUtil extends SparkHadoopUtil {
+
+ // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
+ override def isYarnMode(): Boolean = { true }
+
+ // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
+ // Always create a new config, dont reuse yarnConf.
+ override def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
+
+ // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
+ override def addCredentials(conf: JobConf) {
+ val jobCreds = conf.getCredentials()
+ jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000..63a0449e5a
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.util.Utils
+
+/**
+ *
+ * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ override def postStartHook() {
+
+ // The yarn application is running, but the worker might not yet ready
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(2000L)
+ logInfo("YarnClientClusterScheduler.postStartHook done")
+ }
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..b206780c78
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ var client: Client = null
+ var appId: ApplicationId = null
+
+ override def start() {
+ super.start()
+
+ val defalutWorkerCores = "2"
+ val defalutWorkerMemory = "512m"
+ val defaultWorkerNumber = "1"
+
+ val userJar = System.getenv("SPARK_YARN_APP_JAR")
+ var workerCores = System.getenv("SPARK_WORKER_CORES")
+ var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
+ var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
+
+ if (userJar == null)
+ throw new SparkException("env SPARK_YARN_APP_JAR is not set")
+
+ if (workerCores == null)
+ workerCores = defalutWorkerCores
+ if (workerMemory == null)
+ workerMemory = defalutWorkerMemory
+ if (workerNumber == null)
+ workerNumber = defaultWorkerNumber
+
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+
+ val argsArray = Array[String](
+ "--class", "notused",
+ "--jar", userJar,
+ "--args", hostport,
+ "--worker-memory", workerMemory,
+ "--worker-cores", workerCores,
+ "--num-workers", workerNumber,
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
+ )
+
+ val args = new ClientArguments(argsArray)
+ client = new Client(args)
+ appId = client.runApp()
+ waitForApp()
+ }
+
+ def waitForApp() {
+
+ // TODO : need a better way to find out whether the workers are ready or not
+ // maybe by resource usage report?
+ while(true) {
+ val report = client.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n"
+ )
+
+ // Ready to go, or already gone.
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.RUNNING) {
+ return
+ } else if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application already ended," +
+ "might be killed or not able to launch application master.")
+ }
+
+ Thread.sleep(1000)
+ }
+ }
+
+ override def stop() {
+ super.stop()
+ client.stop()
+ logInfo("Stoped")
+ }
+
+}
diff --git a/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
new file mode 100644
index 0000000000..29b3f22e13
--- /dev/null
+++ b/new-yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.util.Utils
+import org.apache.hadoop.conf.Configuration
+
+/**
+ *
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ */
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ logInfo("Created YarnClusterScheduler")
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
+ // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
+ // Subsequent creations are ignored - since nodes are already allocated by then.
+
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ override def postStartHook() {
+ val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
+ if (sparkContextInitialized){
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(3000L)
+ }
+ logInfo("YarnClusterScheduler.postStartHook done")
+ }
+}
diff --git a/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..2941356bc5
--- /dev/null
+++ b/new-yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito.when
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+
+class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ return LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+
+ //add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val env2 = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env2)
+ val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === "0")
+ assert(sizes(0) === "0")
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === "10")
+ assert(sizes(1) === "20")
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+ }
+
+
+}
diff --git a/pom.xml b/pom.xml
index 334958382a..8c87e431a5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -99,13 +99,17 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
- <java.version>1.5</java.version>
- <scala.version>2.9.3</scala.version>
+ <java.version>1.6</java.version>
+
+ <scala.version>2.10.3</scala.version>
+ <scala.binary.version>2.10</scala.binary.version>
<mesos.version>0.13.0</mesos.version>
- <akka.version>2.0.5</akka.version>
+ <akka.group>org.spark-project.akka</akka.group>
+ <akka.version>2.2.3-shaded-protobuf</akka.version>
<slf4j.version>1.7.2</slf4j.version>
<log4j.version>1.2.17</log4j.version>
<hadoop.version>1.0.4</hadoop.version>
+ <protobuf.version>2.4.1</protobuf.version>
<yarn.version>0.23.7</yarn.version>
<hbase.version>0.94.6</hbase.version>
@@ -114,10 +118,10 @@
</properties>
<repositories>
- <repository>
- <id>jboss-repo</id>
- <name>JBoss Repository</name>
- <url>http://repository.jboss.org/nexus/content/repositories/releases/</url>
+ <repository>
+ <id>maven-repo</id> <!-- This should be at top, it makes maven try the central repo first and then others and hence faster dep resolution -->
+ <name>Maven Repository</name>
+ <url>http://repo.maven.apache.org/maven2</url>
<releases>
<enabled>true</enabled>
</releases>
@@ -126,9 +130,9 @@
</snapshots>
</repository>
<repository>
- <id>cloudera-repo</id>
- <name>Cloudera Repository</name>
- <url>https://repository.cloudera.com/artifactory/cloudera-repos/</url>
+ <id>jboss-repo</id>
+ <name>JBoss Repository</name>
+ <url>http://repository.jboss.org/nexus/content/repositories/releases</url>
<releases>
<enabled>true</enabled>
</releases>
@@ -137,9 +141,9 @@
</snapshots>
</repository>
<repository>
- <id>akka-repo</id>
- <name>Akka Repository</name>
- <url>http://repo.akka.io/releases/</url>
+ <id>mqtt-repo</id>
+ <name>MQTT Repository</name>
+ <url>https://repo.eclipse.org/content/repositories/paho-releases</url>
<releases>
<enabled>true</enabled>
</releases>
@@ -148,41 +152,6 @@
</snapshots>
</repository>
</repositories>
- <pluginRepositories>
- <pluginRepository>
- <id>oss-sonatype-releases</id>
- <name>OSS Sonatype</name>
- <url>https://oss.sonatype.org/content/repositories/releases</url>
- <releases>
- <enabled>true</enabled>
- </releases>
- <snapshots>
- <enabled>false</enabled>
- </snapshots>
- </pluginRepository>
- <pluginRepository>
- <id>oss-sonatype-snapshots</id>
- <name>OSS Sonatype</name>
- <url>https://oss.sonatype.org/content/repositories/snapshots</url>
- <releases>
- <enabled>false</enabled>
- </releases>
- <snapshots>
- <enabled>true</enabled>
- </snapshots>
- </pluginRepository>
- <pluginRepository>
- <id>oss-sonatype</id>
- <name>OSS Sonatype</name>
- <url>https://oss.sonatype.org/content/groups/public</url>
- <releases>
- <enabled>true</enabled>
- </releases>
- <snapshots>
- <enabled>true</enabled>
- </snapshots>
- </pluginRepository>
- </pluginRepositories>
<dependencyManagement>
<dependencies>
@@ -231,6 +200,11 @@
<artifactId>asm</artifactId>
<version>4.0</version>
</dependency>
+ <!-- In theory we need not directly depend on protobuf since Spark does not directly
+ use it. However, when building with Hadoop/YARN 2.2 Maven doesn't correctly bump
+ the protobuf version up from the one Mesos gives. For now we include this variable
+ to explicitly bump the version when building with YARN. It would be nice to figure
+ out why Maven can't resolve this correctly (like SBT does). -->
<dependency>
<groupId>com.clearspring.analytics</groupId>
<artifactId>stream</artifactId>
@@ -239,11 +213,11 @@
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
- <version>2.4.1</version>
+ <version>${protobuf.version}</version>
</dependency>
<dependency>
<groupId>com.twitter</groupId>
- <artifactId>chill_2.9.3</artifactId>
+ <artifactId>chill_${scala.binary.version}</artifactId>
<version>0.3.1</version>
</dependency>
<dependency>
@@ -252,19 +226,48 @@
<version>0.3.1</version>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-actor</artifactId>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-actor_${scala.binary.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-remote</artifactId>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-remote_${scala.binary.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-slf4j</artifactId>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-slf4j_${scala.binary.version}</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-zeromq_${scala.binary.version}</artifactId>
+ <version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
@@ -277,11 +280,6 @@
<version>1.2.0</version>
</dependency>
<dependency>
- <groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_2.9.2</artifactId>
- <version>0.4.1</version>
- </dependency>
- <dependency>
<groupId>org.apache.mesos</groupId>
<artifactId>mesos</artifactId>
<version>${mesos.version}</version>
@@ -289,7 +287,7 @@
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
- <version>4.0.0.Beta2</version>
+ <version>4.0.0.CR1</version>
</dependency>
<dependency>
<groupId>org.apache.derby</groupId>
@@ -299,8 +297,14 @@
</dependency>
<dependency>
<groupId>net.liftweb</groupId>
- <artifactId>lift-json_2.9.2</artifactId>
- <version>2.5</version>
+ <artifactId>lift-json_${scala.binary.version}</artifactId>
+ <version>2.5.1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.scala-lang</groupId>
+ <artifactId>scalap</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.codahale.metrics</groupId>
@@ -323,6 +327,11 @@
<version>3.0.0</version>
</dependency>
<dependency>
+ <groupId>com.codahale.metrics</groupId>
+ <artifactId>metrics-graphite</artifactId>
+ <version>3.0.0</version>
+ </dependency>
+ <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>${scala.version}</version>
@@ -338,32 +347,36 @@
<version>${scala.version}</version>
</dependency>
<dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scalap</artifactId>
- <version>${scala.version}</version>
- </dependency>
-
- <dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
<version>${log4j.version}</version>
</dependency>
-
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<version>1.9.1</version>
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <version>2.4</version>
+ </dependency>
+ <dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<version>3.1</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <version>1.8.5</version>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<version>1.10.0</version>
<scope>test</scope>
</dependency>
@@ -399,19 +412,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -435,19 +440,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -466,22 +463,15 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
+
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-client</artifactId>
@@ -497,19 +487,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -528,6 +510,10 @@
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
@@ -741,6 +727,7 @@
<hadoop.major.version>2</hadoop.major.version>
<!-- 0.23.* is same as 2.0.* - except hardened to run production jobs -->
<hadoop.version>0.23.7</hadoop.version>
+ <protobuf.version>2.5.0</protobuf.version>
<!--<hadoop.version>2.0.5-alpha</hadoop.version> -->
</properties>
@@ -752,7 +739,7 @@
<repository>
<id>maven-root</id>
<name>Maven root repository</name>
- <url>http://repo1.maven.org/maven2/</url>
+ <url>http://repo1.maven.org/maven2</url>
<releases>
<enabled>true</enabled>
</releases>
@@ -767,6 +754,39 @@
</dependencies>
</dependencyManagement>
</profile>
+
+ <profile>
+ <id>new-yarn</id>
+ <properties>
+ <hadoop.major.version>2</hadoop.major.version>
+ <hadoop.version>2.2.0</hadoop.version>
+ <protobuf.version>2.5.0</protobuf.version>
+ </properties>
+
+ <modules>
+ <module>new-yarn</module>
+ </modules>
+
+ <repositories>
+ <repository>
+ <id>maven-root</id>
+ <name>Maven root repository</name>
+ <url>http://repo1.maven.org/maven2</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
+ </repository>
+ </repositories>
+
+ <dependencyManagement>
+ <dependencies>
+ </dependencies>
+ </dependencyManagement>
+ </profile>
+
<profile>
<id>repl-bin</id>
<activation>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b332485e82..7bcbd90bd3 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -28,14 +28,19 @@ object SparkBuild extends Build {
// "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set
// through the environment variables SPARK_HADOOP_VERSION and SPARK_YARN.
val DEFAULT_HADOOP_VERSION = "1.0.4"
+
+ // Whether the Hadoop version to build against is 2.2.x, or a variant of it. This can be set
+ // through the SPARK_IS_NEW_HADOOP environment variable.
+ val DEFAULT_IS_NEW_HADOOP = false
+
val DEFAULT_YARN = false
// HBase version; set as appropriate.
val HBASE_VERSION = "0.94.6"
// Target JVM version
- val SCALAC_JVM_VERSION = "jvm-1.5"
- val JAVAC_JVM_VERSION = "1.5"
+ val SCALAC_JVM_VERSION = "jvm-1.6"
+ val JAVAC_JVM_VERSION = "1.6"
lazy val root = Project("root", file("."), settings = rootSettings) aggregate(allProjects: _*)
@@ -55,32 +60,47 @@ object SparkBuild extends Build {
lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core)
- lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn(core)
-
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
+ lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
// Allows build configuration to be set through environment variables
lazy val hadoopVersion = scala.util.Properties.envOrElse("SPARK_HADOOP_VERSION", DEFAULT_HADOOP_VERSION)
+ lazy val isNewHadoop = scala.util.Properties.envOrNone("SPARK_IS_NEW_HADOOP") match {
+ case None => {
+ val isNewHadoopVersion = "2.[2-9]+".r.findFirstIn(hadoopVersion).isDefined
+ (isNewHadoopVersion|| DEFAULT_IS_NEW_HADOOP)
+ }
+ case Some(v) => v.toBoolean
+ }
+
lazy val isYarnEnabled = scala.util.Properties.envOrNone("SPARK_YARN") match {
case None => DEFAULT_YARN
case Some(v) => v.toBoolean
}
// Conditionally include the yarn sub-project
- lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
- lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
- lazy val allProjects = Seq[ProjectReference](
- core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef
+ lazy val yarn = Project("yarn", file(if (isNewHadoop) "new-yarn" else "yarn"), settings = yarnSettings) dependsOn(core)
+
+ //lazy val yarn = Project("yarn", file("yarn"), settings = yarnSettings) dependsOn(core)
+
+ lazy val maybeYarn = if (isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
+ lazy val maybeYarnRef = if (isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
+
+ // Everything except assembly, tools and examples belong to packageProjects
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef
+
+ lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj)
def sharedSettings = Defaults.defaultSettings ++ Seq(
- organization := "org.apache.spark",
- version := "0.9.0-incubating-SNAPSHOT",
- scalaVersion := "2.9.3",
+ organization := "org.apache.spark",
+ version := "0.9.0-incubating-SNAPSHOT",
+ scalaVersion := "2.10.3",
scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation",
"-target:" + SCALAC_JVM_VERSION),
javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION),
@@ -94,16 +114,16 @@ object SparkBuild extends Build {
fork := true,
javaOptions += "-Xmx3g",
+ // Show full stack trace and duration in test cases.
+ testOptions in Test += Tests.Argument("-oDF"),
+
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
// also check the local Maven repository ~/.m2
resolvers ++= Seq(Resolver.file("Local Maven Repo", file(Path.userHome + "/.m2/repository"))),
- // Shared between both core and streaming.
- resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
-
- // For Sonatype publishing
+ // For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@@ -158,12 +178,19 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
- "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
- "org.scalatest" %% "scalatest" % "1.9.1" % "test",
- "org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
- "com.novocode" % "junit-interface" % "0.9" % "test",
- "org.easymock" % "easymock" % "3.1" % "test"
+ "io.netty" % "netty-all" % "4.0.0.CR1",
+ "org.eclipse.jetty" % "jetty-server" % "7.6.8.v20121106",
+ /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */
+ "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"),
+ "org.scalatest" %% "scalatest" % "1.9.1" % "test",
+ "org.scalacheck" %% "scalacheck" % "1.10.0" % "test",
+ "com.novocode" % "junit-interface" % "0.9" % "test",
+ "org.easymock" % "easymock" % "3.1" % "test",
+ "org.mockito" % "mockito-all" % "1.8.5" % "test",
+ "commons-io" % "commons-io" % "2.4" % "test"
),
+
+ parallelExecution := true,
/* Workaround for issue #206 (fixed after SBT 0.11.0) */
watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task,
const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) },
@@ -188,62 +215,61 @@ object SparkBuild extends Build {
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
resolvers ++= Seq(
- "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
- "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/"
+ "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/",
+ "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/"
),
libraryDependencies ++= Seq(
- "com.google.guava" % "guava" % "14.0.1",
- "com.google.code.findbugs" % "jsr305" % "1.3.9",
- "log4j" % "log4j" % "1.2.17",
- "org.slf4j" % "slf4j-api" % slf4jVersion,
- "org.slf4j" % "slf4j-log4j12" % slf4jVersion,
- "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407
- "com.ning" % "compress-lzf" % "0.8.4",
- "org.xerial.snappy" % "snappy-java" % "1.0.5",
- "org.ow2.asm" % "asm" % "4.0",
- "com.google.protobuf" % "protobuf-java" % "2.4.1",
- "com.typesafe.akka" % "akka-actor" % "2.0.5" excludeAll(excludeNetty),
- "com.typesafe.akka" % "akka-remote" % "2.0.5" excludeAll(excludeNetty),
- "com.typesafe.akka" % "akka-slf4j" % "2.0.5" excludeAll(excludeNetty),
- "it.unimi.dsi" % "fastutil" % "6.4.4",
- "colt" % "colt" % "1.2.0",
- "net.liftweb" % "lift-json_2.9.2" % "2.5",
- "org.apache.mesos" % "mesos" % "0.13.0",
- "io.netty" % "netty-all" % "4.0.0.Beta2",
- "org.apache.derby" % "derby" % "10.4.2.0" % "test",
- "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
- "net.java.dev.jets3t" % "jets3t" % "0.7.1",
- "org.apache.avro" % "avro" % "1.7.4",
- "org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty),
- "org.apache.zookeeper" % "zookeeper" % "3.4.5" excludeAll(excludeNetty),
- "com.codahale.metrics" % "metrics-core" % "3.0.0",
- "com.codahale.metrics" % "metrics-jvm" % "3.0.0",
- "com.codahale.metrics" % "metrics-json" % "3.0.0",
- "com.codahale.metrics" % "metrics-ganglia" % "3.0.0",
- "com.twitter" % "chill_2.9.3" % "0.3.1",
- "com.twitter" % "chill-java" % "0.3.1",
- "com.clearspring.analytics" % "stream" % "2.4.0"
- )
+ "com.google.guava" % "guava" % "14.0.1",
+ "com.google.code.findbugs" % "jsr305" % "1.3.9",
+ "log4j" % "log4j" % "1.2.17",
+ "org.slf4j" % "slf4j-api" % slf4jVersion,
+ "org.slf4j" % "slf4j-log4j12" % slf4jVersion,
+ "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407
+ "com.ning" % "compress-lzf" % "0.8.4",
+ "org.xerial.snappy" % "snappy-java" % "1.0.5",
+ "org.ow2.asm" % "asm" % "4.0",
+ "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
+ "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty),
+ "net.liftweb" %% "lift-json" % "2.5.1" excludeAll(excludeNetty),
+ "it.unimi.dsi" % "fastutil" % "6.4.4",
+ "colt" % "colt" % "1.2.0",
+ "org.apache.mesos" % "mesos" % "0.13.0",
+ "net.java.dev.jets3t" % "jets3t" % "0.7.1",
+ "org.apache.derby" % "derby" % "10.4.2.0" % "test",
+ "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib),
+ "org.apache.avro" % "avro" % "1.7.4",
+ "org.apache.avro" % "avro-ipc" % "1.7.4" excludeAll(excludeNetty),
+ "org.apache.zookeeper" % "zookeeper" % "3.4.5" excludeAll(excludeNetty),
+ "com.codahale.metrics" % "metrics-core" % "3.0.0",
+ "com.codahale.metrics" % "metrics-jvm" % "3.0.0",
+ "com.codahale.metrics" % "metrics-json" % "3.0.0",
+ "com.codahale.metrics" % "metrics-ganglia" % "3.0.0",
+ "com.codahale.metrics" % "metrics-graphite" % "3.0.0",
+ "com.twitter" %% "chill" % "0.3.1",
+ "com.twitter" % "chill-java" % "0.3.1"
+ )
)
def rootSettings = sharedSettings ++ Seq(
publish := {}
)
- def replSettings = sharedSettings ++ Seq(
+ def replSettings = sharedSettings ++ Seq(
name := "spark-repl",
- libraryDependencies <+= scalaVersion("org.scala-lang" % "scala-compiler" % _)
+ libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ),
+ libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "jline" % v ),
+ libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-reflect" % v )
)
+
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
libraryDependencies ++= Seq(
- "com.twitter" % "algebird-core_2.9.2" % "0.1.11",
-
+ "com.twitter" %% "algebird-core" % "0.1.11",
+ "org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
"org.apache.hbase" % "hbase" % HBASE_VERSION excludeAll(excludeNetty, excludeAsm),
-
- "org.apache.cassandra" % "cassandra-all" % "1.2.5"
+ "org.apache.cassandra" % "cassandra-all" % "1.2.6"
exclude("com.google.guava", "guava")
exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
exclude("com.ning","compress-lzf")
@@ -258,7 +284,7 @@ object SparkBuild extends Build {
def toolsSettings = sharedSettings ++ Seq(
name := "spark-tools"
- )
+ ) ++ assemblySettings ++ extraAssemblySettings
def bagelSettings = sharedSettings ++ Seq(
name := "spark-bagel"
@@ -274,13 +300,21 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
resolvers ++= Seq(
- "Akka Repository" at "http://repo.akka.io/releases/"
+ "Eclipse Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/",
+ "Apache repo" at "https://repository.apache.org/content/repositories/releases"
),
+
libraryDependencies ++= Seq(
- "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
- "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
- "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
- "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty)
+ "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
+ "com.sksamuel.kafka" %% "kafka" % "0.8.0-beta1"
+ exclude("com.sun.jdmk", "jmxtools")
+ exclude("com.sun.jmx", "jmxri")
+ exclude("net.sf.jopt-simple", "jopt-simple")
+ excludeAll(excludeNetty),
+ "org.eclipse.paho" % "mqtt-client" % "0.4.0",
+ "com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
+ "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
+ "org.spark-project.akka" %% "akka-zeromq" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty)
)
)
@@ -304,7 +338,9 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
+ assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
+ jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
@@ -313,7 +349,7 @@ object SparkBuild extends Build {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
case "log4j.properties" => MergeStrategy.discard
- case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat
+ case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}
diff --git a/project/plugins.sbt b/project/plugins.sbt
index cfcd85082a..4ba0e4280a 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -4,7 +4,7 @@ resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/release
resolvers += "Spray Repository" at "http://repo.spray.cc/"
-addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.1")
+addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.2")
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0")
diff --git a/pyspark b/pyspark
index 4941a36d0d..12cc926dda 100755
--- a/pyspark
+++ b/pyspark
@@ -23,7 +23,7 @@ FWDIR="$(cd `dirname $0`; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
-SCALA_VERSION=2.9.3
+SCALA_VERSION=2.10
# Exit if the user hasn't compiled Spark
if [ ! -f "$FWDIR/RELEASE" ]; then
@@ -59,8 +59,12 @@ if [ -n "$IPYTHON_OPTS" ]; then
fi
if [[ "$IPYTHON" = "1" ]] ; then
- IPYTHON_OPTS=${IPYTHON_OPTS:--i}
- exec ipython "$IPYTHON_OPTS" -c "%run $PYTHONSTARTUP"
+ # IPython <1.0.0 doesn't honor PYTHONSTARTUP, while 1.0.0+ does.
+ # Hence we clear PYTHONSTARTUP and use the -c "%run $IPYTHONSTARTUP" command which works on all versions
+ # We also force interactive mode with "-i"
+ IPYTHONSTARTUP=$PYTHONSTARTUP
+ PYTHONSTARTUP=
+ exec ipython "$IPYTHON_OPTS" -i -c "%run $IPYTHONSTARTUP"
else
exec "$PYSPARK_PYTHON" "$@"
fi
diff --git a/pyspark2.cmd b/pyspark2.cmd
index f58e349643..21f9a34388 100644
--- a/pyspark2.cmd
+++ b/pyspark2.cmd
@@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SCALA_VERSION=2.9.3
+set SCALA_VERSION=2.10
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36..0b42e729f8 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
private: no
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index d367f91967..2204e9c9ca 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
>>> a.value
13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+... b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@@ -83,8 +90,10 @@ import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
+
+pickleSer = PickleSerializer()
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
@@ -139,9 +148,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
+ def add(self, term):
+ """Adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
- self._value = self.accum_param.addInPlace(self._value, term)
+ self.add(term)
return self
def __str__(self):
@@ -200,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
- (aid, update) = load_pickle(read_with_length(self.rfile))
+ (aid, update) = pickleSer._read_with_length(self.rfile)
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 597110321a..0604f6836c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -42,15 +42,15 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
- _takePartition = None
+ _writeToFile = None
_next_accum_id = 0
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
+
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024):
+ environment=None, batchSize=1024, serializer=PickleSerializer()):
"""
Create a new SparkContext.
@@ -66,24 +66,30 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+ @param serializer: The serializer for RDDs.
+
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+
+ >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
- with SparkContext._lock:
- if SparkContext._active_spark_context:
- raise ValueError("Cannot run multiple SparkContexts at once")
- else:
- SparkContext._active_spark_context = self
- if not SparkContext._gateway:
- SparkContext._gateway = launch_gateway()
- SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
- SparkContext._takePartition = \
- SparkContext._jvm.PythonRDD.takePartition
+ SparkContext._ensure_initialized(self)
+
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
self.environment = environment or {}
- self.batchSize = batchSize # -1 represents a unlimited batch size
+ self._batchSize = batchSize # -1 represents an unlimited batch size
+ self._unbatched_serializer = serializer
+ if batchSize == 1:
+ self.serializer = self._unbatched_serializer
+ else:
+ self.serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
# Create the Java SparkContext through Py4J
empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -119,6 +125,30 @@ class SparkContext(object):
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ @classmethod
+ def _ensure_initialized(cls, instance=None):
+ with SparkContext._lock:
+ if not SparkContext._gateway:
+ SparkContext._gateway = launch_gateway()
+ SparkContext._jvm = SparkContext._gateway.jvm
+ SparkContext._writeToFile = \
+ SparkContext._jvm.PythonRDD.writeToFile
+
+ if instance:
+ if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
+ raise ValueError("Cannot run multiple SparkContexts at once")
+ else:
+ SparkContext._active_spark_context = instance
+
+ @classmethod
+ def setSystemProperty(cls, key, value):
+ """
+ Set a system property, such as spark.executor.memory. This must be
+ invoked before instantiating SparkContext.
+ """
+ SparkContext._ensure_initialized()
+ SparkContext._jvm.java.lang.System.setProperty(key, value)
+
@property
def defaultParallelism(self):
"""
@@ -158,15 +188,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self.batchSize)
+ batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
- c = batched(c, batchSize)
- for x in c:
- write_with_length(dump_pickle(x), tempFile)
+ serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+ else:
+ serializer = self._unbatched_serializer
+ serializer.dump_stream(c, tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
- return RDD(jrdd, self)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@@ -175,21 +207,39 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jsc.textFile(name, minSplits)
- return RDD(jrdd, self)
+ return RDD(self._jsc.textFile(name, minSplits), self,
+ MUTF8Deserializer())
- def _checkpointFile(self, name):
+ def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
- return RDD(jrdd, self)
+ return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
+
+ This supports unions() of RDDs with different serialized formats,
+ although this forces them to be reserialized using the default
+ serializer:
+
+ >>> path = os.path.join(tempdir, "union-text.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("Hello")
+ >>> textFile = sc.textFile(path)
+ >>> textFile.collect()
+ [u'Hello']
+ >>> parallelized = sc.parallelize(["World!"])
+ >>> sorted(sc.union([textFile, parallelized]).collect())
+ [u'Hello', 'World!']
"""
+ first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+ if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+ rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self.gateway._gateway_client)
- return RDD(self._jsc.union(first, rest), self)
+ rest = ListConverter().convert(rest, self._gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self,
+ rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
@@ -197,7 +247,9 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
- jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ pickleSer = PickleSerializer()
+ pickled = pickleSer.dumps(value)
+ jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@@ -209,7 +261,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
- if accum_param == None:
+ if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8bee..f87923e6fa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
import operator
import os
import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -48,12 +47,15 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx):
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
- self._partitionFunc = None
+ self._jrdd_deserializer = jrdd_deserializer
+
+ def __repr__(self):
+ return self._jrdd.toString()
@property
def context(self):
@@ -247,7 +249,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return RDD(self._jrdd.union(other._jrdd), self.ctx)
+ if self._jrdd_deserializer == other._jrdd_deserializer:
+ rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+ self._jrdd_deserializer)
+ return rdd
+ else:
+ # These RDDs contain data in different serialized formats, so we
+ # must normalize them to the default serializer.
+ self_copy = self._reserialize()
+ other_copy = other._reserialize()
+ return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+
+ def _reserialize(self):
+ if self._jrdd_deserializer == self.ctx.serializer:
+ return self
+ else:
+ return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@@ -334,17 +352,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
- def unpack_batches(pair):
- (x, y) = pair
- if type(x) == Batch or type(y) == Batch:
- xs = x.items if type(x) == Batch else [x]
- ys = y.items if type(y) == Batch else [y]
- for pair in product(xs, ys):
- yield pair
- else:
- yield pair
- return java_cartesian.flatMap(unpack_batches)
+ deserializer = CartesianDeserializer(self._jrdd_deserializer,
+ other._jrdd_deserializer)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@@ -391,8 +401,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -400,10 +410,10 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
- for item in read_from_pickle_file(tempFile):
+ for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
@@ -569,9 +579,14 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
- items.extend(self._collect_iterator_through_file(iterator))
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@@ -598,7 +613,10 @@ class RDD(object):
'0\\n1\\n2\\n3\\n4\\n5\\n6\\n7\\n8\\n9\\n'
"""
def func(split, iterator):
- return (str(x).encode("utf-8") for x in iterator)
+ for x in iterator:
+ if not isinstance(x, basestring):
+ x = unicode(x)
+ yield x.encode("utf-8")
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
@@ -735,6 +753,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
+ outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@@ -743,14 +762,14 @@ class RDD(object):
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield pack_long(split)
- yield dump_pickle(Batch(items))
+ yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
- rdd = RDD(jrdd, self.ctx)
+ rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
@@ -787,7 +806,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
- for (k, v) in iterator:
+ for x in iterator:
+ (k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
@@ -929,50 +949,52 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+ # This transformation is the first in its stage:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_jrdd_deserializer = prev._jrdd_deserializer
+ else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
- self._prev_jrdd = prev._prev_jrdd
- else:
- self.func = func
- self.preservesPartitioning = preservesPartitioning
- self._prev_jrdd = prev._jrdd
+ self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
+ self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- func = self.func
- if not self._bypass_serializer and self.ctx.batchSize != 1:
- oldfunc = self.func
- batchSize = self.ctx.batchSize
- def batched_func(split, iterator):
- return batched(oldfunc(split, iterator), batchSize)
- func = batched_func
- cmds = [func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ if self._bypass_serializer:
+ serializer = NoOpSerializer()
+ else:
+ serializer = self.ctx.serializer
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer().dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
- class_manifest = self._prev_jrdd.classManifest()
+ class_tag = self._prev_jrdd.classTag()
env = MapConverter().convert(self.ctx.environment,
self.ctx._gateway._gateway_client)
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c7..811fa6f018 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
# limitations under the License.
#
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serialize objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+ END_OF_DATA_SECTION = -1
+ PYTHON_EXCEPTION_THROWN = -2
+ TIMING_DATA = -3
+
+
+class Serializer(object):
+
+ def dump_stream(self, iterator, stream):
+ """
+ Serialize an iterator of objects to the output stream.
+ """
+ raise NotImplementedError
+
+ def load_stream(self, stream):
+ """
+ Return an iterator of deserialized objects from the input stream.
+ """
+ raise NotImplementedError
+
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.load_stream(stream)
+
+ # Note: our notion of "equality" is that output generated by
+ # equal serializers can be deserialized using the same serializer.
+
+ # This default implementation handles the simple cases;
+ # subclasses should override __eq__ as appropriate.
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+ """
+ Serializer that writes objects as a stream of (length, data) pairs,
+ where C{length} is a 32-bit integer and data is C{length} bytes.
+ """
+
+ def dump_stream(self, iterator, stream):
+ for obj in iterator:
+ self._write_with_length(obj, stream)
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self._read_with_length(stream)
+ except EOFError:
+ return
+
+ def _write_with_length(self, obj, stream):
+ serialized = self.dumps(obj)
+ write_int(len(serialized), stream)
+ stream.write(serialized)
+
+ def _read_with_length(self, stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return self.loads(obj)
+
+ def dumps(self, obj):
+ """
+ Serialize an object into a byte array.
+ When batching is used, this will be called with an array of objects.
+ """
+ raise NotImplementedError
+
+ def loads(self, obj):
+ """
+ Deserialize an object from a byte array.
+ """
+ raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+ """
+ Serializes a stream of objects in batches by calling its wrapped
+ Serializer with streams of objects.
+ """
+
+ UNLIMITED_BATCH_SIZE = -1
+
+ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+ self.serializer = serializer
+ self.batchSize = batchSize
+
+ def _batched(self, iterator):
+ if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+ yield list(iterator)
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == self.batchSize:
+ yield items
+ items = []
+ count = 0
+ if items:
+ yield items
+
+ def dump_stream(self, iterator, stream):
+ self.serializer.dump_stream(self._batched(iterator), stream)
+
+ def load_stream(self, stream):
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __eq__(self, other):
+ return isinstance(other, BatchedSerializer) and \
+ other.serializer == self.serializer
+
+ def __str__(self):
+ return "BatchedSerializer<%s>" % str(self.serializer)
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
"""
- Used to store multiple RDD entries as a single Java object.
+ Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ """
+
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def load_stream(self, stream):
+ key_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_stream = self.val_ser._load_stream_without_unbatching(stream)
+ key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+ val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+ for (keys, vals) in izip(key_stream, val_stream):
+ keys = keys if key_is_batched else [keys]
+ vals = vals if val_is_batched else [vals]
+ for pair in product(keys, vals):
+ yield pair
+
+ def __eq__(self, other):
+ return isinstance(other, CartesianDeserializer) and \
+ self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+ def __str__(self):
+ return "CartesianDeserializer<%s, %s>" % \
+ (str(self.key_ser), str(self.val_ser))
- This relieves us from having to explicitly track whether an RDD
- is stored as batches of objects and avoids problems when processing
- the union() of batched and unbatched RDDs (e.g. the union() of textFile()
- with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+ def loads(self, obj): return obj
+ def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
"""
- def __init__(self, items):
- self.items = items
+ Serializes objects using Python's cPickle serializer:
+ http://docs.python.org/2/library/pickle.html
-def batched(iterator, batchSize):
- if batchSize == -1: # unlimited batch size
- yield Batch(list(iterator))
- else:
- items = []
- count = 0
- for item in iterator:
- items.append(item)
- count += 1
- if count == batchSize:
- yield Batch(items)
- items = []
- count = 0
- if items:
- yield Batch(items)
+ This serializer supports nearly any Python object, but may
+ not be as fast as more specialized serializers.
+ """
+ def dumps(self, obj): return cPickle.dumps(obj, 2)
+ loads = cPickle.loads
-def dump_pickle(obj):
- return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
+ def dumps(self, obj): return cloudpickle.dumps(obj, 2)
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+ """
+ Serializes objects using Python's Marshal serializer:
+
+ http://docs.python.org/2/library/marshal.html
+
+ This serializer is faster than PickleSerializer but supports fewer datatypes.
+ """
+
+ dumps = marshal.dumps
+ loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+ """
+ Deserializes streams written by Java's DataOutputStream.writeUTF().
+ """
+
+ def loads(self, stream):
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def read_long(stream):
@@ -84,25 +308,4 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
- stream.write(obj)
-
-
-def read_with_length(stream):
- length = read_int(stream)
- obj = stream.read(length)
- if obj == "":
- raise EOFError
- return obj
-
-
-def read_from_pickle_file(stream):
- try:
- while True:
- obj = load_pickle(read_with_length(stream))
- if type(obj) == Batch: # We don't care about inheritance
- for item in obj.items:
- yield item
- else:
- yield obj
- except EOFError:
- return
+ stream.write(obj) \ No newline at end of file
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6..3987642bf4 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -19,6 +19,8 @@
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
+from fileinput import input
+from glob import glob
import os
import shutil
import sys
@@ -86,7 +88,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
- recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+ flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())
@@ -137,6 +140,19 @@ class TestAddFile(PySparkTestCase):
self.assertEqual("Hello World from inside a package!", UserClass().hello())
+class TestRDDFunctions(PySparkTestCase):
+
+ def test_save_as_textfile_with_unicode(self):
+ # Regression test for SPARK-970
+ x = u"\u00A1Hola, mundo!"
+ data = self.sc.parallelize([x])
+ tempFile = NamedTemporaryFile(delete=True)
+ tempFile.close()
+ data.saveAsTextFile(tempFile.name)
+ raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
+ self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
+
+
class TestIO(PySparkTestCase):
def test_stdout_redirection(self):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef7..f2b3f3c142 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
import time
import socket
import traceback
-from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+ write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
-def load_obj(infile):
- return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
def report_times(outfile, boot, init, finish):
- write_int(-3, outfile)
+ write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = load_pickle(read_with_length(infile))
+ spark_files_dir = mutf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -60,38 +59,33 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+ filename = mutf8_deserializer.loads(infile)
+ sys.path.append(os.path.join(spark_files_dir, filename))
- # now load function
- func = load_obj(infile)
- bypassSerializer = load_obj(infile)
- if bypassSerializer:
- dumps = lambda x: x
- else:
- dumps = dump_pickle
+ command = pickleSer._read_with_length(infile)
+ (func, deserializer, serializer) = command
init_time = time.time()
- iterator = read_from_pickle_file(infile)
try:
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), outfile)
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
- write_int(-2, outfile)
+ write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, outfile)
- for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), outfile)
- write_int(-1, outfile)
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ write_int(len(_accumulatorRegistry), outfile)
+ for (aid, accum) in _accumulatorRegistry.items():
+ pickleSer._write_with_length((aid, accum._value), outfile)
if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9d..d4dad672d2 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then
diff --git a/python/test_support/userlibrary.py b/python/test_support/userlibrary.py
index 5bb6f5009f..8e4a6292bc 100755
--- a/python/test_support/userlibrary.py
+++ b/python/test_support/userlibrary.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
"""
Used to test shipping of code depenencies with SparkContext.addPyFile().
"""
diff --git a/repl-bin/pom.xml b/repl-bin/pom.xml
index f6bf94be6b..869dbdb9b0 100644
--- a/repl-bin/pom.xml
+++ b/repl-bin/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl-bin_2.9.3</artifactId>
+ <artifactId>spark-repl-bin_2.10</artifactId>
<packaging>pom</packaging>
<name>Spark Project REPL binary packaging</name>
<url>http://spark.incubator.apache.org/</url>
@@ -40,18 +40,18 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel_2.9.3</artifactId>
+ <artifactId>spark-bagel_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl_2.9.3</artifactId>
+ <artifactId>spark-repl_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
diff --git a/repl-bin/src/deb/bin/run b/repl-bin/src/deb/bin/run
index 8b5d8300f2..47bb654baf 100755
--- a/repl-bin/src/deb/bin/run
+++ b/repl-bin/src/deb/bin/run
@@ -17,7 +17,7 @@
# limitations under the License.
#
-SCALA_VERSION=2.9.3
+SCALA_VERSION=2.10
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"
diff --git a/repl-bin/src/deb/bin/spark-executor b/repl-bin/src/deb/bin/spark-executor
index bcfae22677..052d76fb8d 100755
--- a/repl-bin/src/deb/bin/spark-executor
+++ b/repl-bin/src/deb/bin/spark-executor
@@ -19,4 +19,4 @@
FWDIR="$(cd `dirname $0`; pwd)"
echo "Running spark-executor with framework dir = $FWDIR"
-exec $FWDIR/run spark.executor.MesosExecutorBackend
+exec $FWDIR/run org.apache.spark.executor.MesosExecutorBackend
diff --git a/repl-bin/src/deb/bin/spark-shell b/repl-bin/src/deb/bin/spark-shell
index ec7e33e1e3..118349d7c3 100755
--- a/repl-bin/src/deb/bin/spark-shell
+++ b/repl-bin/src/deb/bin/spark-shell
@@ -18,4 +18,4 @@
#
FWDIR="$(cd `dirname $0`; pwd)"
-exec $FWDIR/run spark.repl.Main "$@"
+exec $FWDIR/run org.apache.spark.repl.Main "$@"
diff --git a/repl/lib/scala-jline.jar b/repl/lib/scala-jline.jar
deleted file mode 100644
index 2f18c95cdd..0000000000
--- a/repl/lib/scala-jline.jar
+++ /dev/null
Binary files differ
diff --git a/repl/pom.xml b/repl/pom.xml
index 49d86621dd..b0e7877bbb 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-repl_2.9.3</artifactId>
+ <artifactId>spark-repl_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project REPL</name>
<url>http://spark.incubator.apache.org/</url>
@@ -39,18 +39,18 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-bagel_2.9.3</artifactId>
+ <artifactId>spark-bagel_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-mllib_2.9.3</artifactId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>runtime</scope>
</dependency>
@@ -61,10 +61,12 @@
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
+ <version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>jline</artifactId>
+ <version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
@@ -76,18 +78,18 @@
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/src/main/scala/org/apache/spark/repl/Main.scala
index 17e149f8ab..14b448d076 100644
--- a/repl/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/Main.scala
@@ -20,12 +20,12 @@ package org.apache.spark.repl
import scala.collection.mutable.Set
object Main {
- private var _interp: SparkILoop = null
-
+ private var _interp: SparkILoop = _
+
def interp = _interp
-
+
def interp_=(i: SparkILoop) { _interp = i }
-
+
def main(args: Array[String]) {
_interp = new SparkILoop
_interp.process(args)
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
new file mode 100644
index 0000000000..b2e1df173e
--- /dev/null
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
@@ -0,0 +1,109 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter._
+
+import scala.reflect.internal.util.BatchSourceFile
+import scala.tools.nsc.ast.parser.Tokens.EOF
+
+import org.apache.spark.Logging
+
+trait SparkExprTyper extends Logging {
+ val repl: SparkIMain
+
+ import repl._
+ import global.{ reporter => _, Import => _, _ }
+ import definitions._
+ import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name }
+ import naming.freshInternalVarName
+
+ object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] {
+ def applyRule[T](code: String, rule: UnitParser => T): T = {
+ reporter.reset()
+ val scanner = newUnitParser(code)
+ val result = rule(scanner)
+
+ if (!reporter.hasErrors)
+ scanner.accept(EOF)
+
+ result
+ }
+
+ def defns(code: String) = stmts(code) collect { case x: DefTree => x }
+ def expr(code: String) = applyRule(code, _.expr())
+ def stmts(code: String) = applyRule(code, _.templateStats())
+ def stmt(code: String) = stmts(code).last // guaranteed nonempty
+ }
+
+ /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
+ def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") {
+ var isIncomplete = false
+ reporter.withIncompleteHandler((_, _) => isIncomplete = true) {
+ val trees = codeParser.stmts(line)
+ if (reporter.hasErrors) Some(Nil)
+ else if (isIncomplete) None
+ else Some(trees)
+ }
+ }
+ // def parsesAsExpr(line: String) = {
+ // import codeParser._
+ // (opt expr line).isDefined
+ // }
+
+ def symbolOfLine(code: String): Symbol = {
+ def asExpr(): Symbol = {
+ val name = freshInternalVarName()
+ // Typing it with a lazy val would give us the right type, but runs
+ // into compiler bugs with things like existentials, so we compile it
+ // behind a def and strip the NullaryMethodType which wraps the expr.
+ val line = "def " + name + " = {\n" + code + "\n}"
+
+ interpretSynthetic(line) match {
+ case IR.Success =>
+ val sym0 = symbolOfTerm(name)
+ // drop NullaryMethodType
+ val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType)
+ if (sym.info.typeSymbol eq UnitClass) NoSymbol
+ else sym
+ case _ => NoSymbol
+ }
+ }
+ def asDefn(): Symbol = {
+ val old = repl.definedSymbolList.toSet
+
+ interpretSynthetic(code) match {
+ case IR.Success =>
+ repl.definedSymbolList filterNot old match {
+ case Nil => NoSymbol
+ case sym :: Nil => sym
+ case syms => NoSymbol.newOverloaded(NoPrefix, syms)
+ }
+ case _ => NoSymbol
+ }
+ }
+ beQuietDuring(asExpr()) orElse beQuietDuring(asDefn())
+ }
+
+ private var typeOfExpressionDepth = 0
+ def typeOfExpression(expr: String, silent: Boolean = true): Type = {
+ if (typeOfExpressionDepth > 2) {
+ logDebug("Terminating typeOfExpression recursion for expression: " + expr)
+ return NoType
+ }
+ typeOfExpressionDepth += 1
+ // Don't presently have a good way to suppress undesirable success output
+ // while letting errors through, so it is first trying it silently: if there
+ // is an error, and errors are desired, then it re-evaluates non-silently
+ // to induce the error message.
+ try beSilentDuring(symbolOfLine(expr).tpe) match {
+ case NoType if !silent => symbolOfLine(expr).tpe // generate error
+ case tpe => tpe
+ }
+ finally typeOfExpressionDepth -= 1
+ }
+}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 36f54a22cf..523fd1222d 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -1,26 +1,38 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Alexander Spoon
*/
package org.apache.spark.repl
+
import scala.tools.nsc._
import scala.tools.nsc.interpreter._
+import scala.tools.nsc.interpreter.{ Results => IR }
import Predef.{ println => _, _ }
-import java.io.{ BufferedReader, FileReader, PrintWriter }
+import java.io.{ BufferedReader, FileReader }
+import java.util.concurrent.locks.ReentrantLock
import scala.sys.process.Process
-import session._
-import scala.tools.nsc.interpreter.{ Results => IR }
-import scala.tools.util.{ SignalManager, Signallable, Javap }
+import scala.tools.nsc.interpreter.session._
+import scala.util.Properties.{ jdkHome, javaVersion }
+import scala.tools.util.{ Javap }
import scala.annotation.tailrec
-import scala.util.control.Exception.{ ignoring }
import scala.collection.mutable.ListBuffer
import scala.concurrent.ops
-import util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
-import interpreter._
-import io.{ File, Sources }
+import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
+import scala.tools.nsc.interpreter._
+import scala.tools.nsc.io.{ File, Directory }
+import scala.reflect.NameTransformer._
+import scala.tools.nsc.util.ScalaClassLoader
+import scala.tools.nsc.util.ScalaClassLoader._
+import scala.tools.util._
+import scala.language.{implicitConversions, existentials}
+import scala.reflect.{ClassTag, classTag}
+import scala.tools.reflect.StdRuntimeTags._
+
+import java.lang.{Class => jClass}
+import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
import org.apache.spark.Logging
import org.apache.spark.SparkContext
@@ -37,45 +49,86 @@ import org.apache.spark.SparkContext
* @author Lex Spoon
* @version 1.2
*/
-class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: Option[String])
+class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
+ val master: Option[String])
extends AnyRef
with LoopCommands
+ with SparkILoopInit
with Logging
{
- def this(in0: BufferedReader, out: PrintWriter, master: String) = this(Some(in0), out, Some(master))
- def this(in0: BufferedReader, out: PrintWriter) = this(Some(in0), out, None)
- def this() = this(None, new PrintWriter(Console.out, true), None)
-
+ def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master))
+ def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None)
+ def this() = this(None, new JPrintWriter(Console.out, true), None)
+
var in: InteractiveReader = _ // the input stream from which commands come
var settings: Settings = _
var intp: SparkIMain = _
- /*
- lazy val power = {
- val g = intp.global
- Power[g.type](this, g)
- }
- */
-
- // TODO
- // object opt extends AestheticSettings
- //
- @deprecated("Use `intp` instead.", "2.9.0")
- def interpreter = intp
-
- @deprecated("Use `intp` instead.", "2.9.0")
- def interpreter_= (i: SparkIMain): Unit = intp = i
-
+ @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
+ @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i
+
+ /** Having inherited the difficult "var-ness" of the repl instance,
+ * I'm trying to work around it by moving operations into a class from
+ * which it will appear a stable prefix.
+ */
+ private def onIntp[T](f: SparkIMain => T): T = f(intp)
+
+ class IMainOps[T <: SparkIMain](val intp: T) {
+ import intp._
+ import global._
+
+ def printAfterTyper(msg: => String) =
+ intp.reporter printMessage afterTyper(msg)
+
+ /** Strip NullaryMethodType artifacts. */
+ private def replInfo(sym: Symbol) = {
+ sym.info match {
+ case NullaryMethodType(restpe) if sym.isAccessor => restpe
+ case info => info
+ }
+ }
+ def echoTypeStructure(sym: Symbol) =
+ printAfterTyper("" + deconstruct.show(replInfo(sym)))
+
+ def echoTypeSignature(sym: Symbol, verbose: Boolean) = {
+ if (verbose) SparkILoop.this.echo("// Type signature")
+ printAfterTyper("" + replInfo(sym))
+
+ if (verbose) {
+ SparkILoop.this.echo("\n// Internal Type structure")
+ echoTypeStructure(sym)
+ }
+ }
+ }
+ implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp)
+
+ /** TODO -
+ * -n normalize
+ * -l label with case class parameter names
+ * -c complete - leave nothing out
+ */
+ private def typeCommandInternal(expr: String, verbose: Boolean): Result = {
+ onIntp { intp =>
+ val sym = intp.symbolOfLine(expr)
+ if (sym.exists) intp.echoTypeSignature(sym, verbose)
+ else ""
+ }
+ }
+
+ var sparkContext: SparkContext = _
+
+ override def echoCommandMessage(msg: String) {
+ intp.reporter printMessage msg
+ }
+
+ // def isAsync = !settings.Yreplsync.value
+ def isAsync = false
+ // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
def history = in.history
/** The context class loader at the time this object was created */
protected val originalClassLoader = Thread.currentThread.getContextClassLoader
- // Install a signal handler so we can be prodded.
- private val signallable =
- /*if (isReplDebug) Signallable("Dump repl state.")(dumpCommand())
- else*/ null
-
// classpath entries added via :cp
var addedClasspath: String = ""
@@ -87,74 +140,49 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
/** Record a command for replay should the user request a :replay */
def addReplay(cmd: String) = replayCommandStack ::= cmd
-
- /** Try to install sigint handler: ignore failure. Signal handler
- * will interrupt current line execution if any is in progress.
- *
- * Attempting to protect the repl from accidental exit, we only honor
- * a single ctrl-C if the current buffer is empty: otherwise we look
- * for a second one within a short time.
- */
- private def installSigIntHandler() {
- def onExit() {
- Console.println("") // avoiding "shell prompt in middle of line" syndrome
- sys.exit(1)
- }
- ignoring(classOf[Exception]) {
- SignalManager("INT") = {
- if (intp == null)
- onExit()
- else if (intp.lineManager.running)
- intp.lineManager.cancel()
- else if (in.currentLine != "") {
- // non-empty buffer, so make them hit ctrl-C a second time
- SignalManager("INT") = onExit()
- io.timer(5)(installSigIntHandler()) // and restore original handler if they don't
- }
- else onExit()
- }
- }
+
+ def savingReplayStack[T](body: => T): T = {
+ val saved = replayCommandStack
+ try body
+ finally replayCommandStack = saved
+ }
+ def savingReader[T](body: => T): T = {
+ val saved = in
+ try body
+ finally in = saved
}
+
+ def sparkCleanUp(){
+ echo("Stopping spark context.")
+ intp.beQuietDuring {
+ command("sc.stop()")
+ }
+ }
/** Close the interpreter and set the var to null. */
def closeInterpreter() {
if (intp ne null) {
- intp.close
+ sparkCleanUp()
+ intp.close()
intp = null
- Thread.currentThread.setContextClassLoader(originalClassLoader)
}
}
-
+
class SparkILoopInterpreter extends SparkIMain(settings, out) {
+ outer =>
+
override lazy val formatting = new Formatting {
def prompt = SparkILoop.this.prompt
}
- override protected def createLineManager() = new Line.Manager {
- override def onRunaway(line: Line[_]): Unit = {
- val template = """
- |// She's gone rogue, captain! Have to take her out!
- |// Calling Thread.stop on runaway %s with offending code:
- |// scala> %s""".stripMargin
-
- echo(template.format(line.thread, line.code))
- // XXX no way to suppress the deprecation warning
- line.thread.stop()
- in.redrawLine()
- }
- }
- override protected def parentClassLoader = {
- SparkHelper.explicitParentLoader(settings).getOrElse( classOf[SparkILoop].getClassLoader )
- }
+ override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
}
/** Create a new interpreter. */
def createInterpreter() {
if (addedClasspath != "")
settings.classpath append addedClasspath
-
+
intp = new SparkILoopInterpreter
- intp.setContextClassLoader()
- installSigIntHandler()
}
/** print a friendly help message */
@@ -168,10 +196,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
private def helpSummary() = {
val usageWidth = commands map (_.usageMsg.length) max
val formatStr = "%-" + usageWidth + "s %s %s"
-
+
echo("All commands can be abbreviated, e.g. :he instead of :help.")
echo("Those marked with a * have more detailed help, e.g. :help imports.\n")
-
+
commands foreach { cmd =>
val star = if (cmd.hasLongHelp) "*" else " "
echo(formatStr.format(cmd.usageMsg, star, cmd.help))
@@ -182,7 +210,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
case Nil => echo(cmd + ": no such command. Type :help for help.")
case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
}
- Result(true, None)
+ Result(true, None)
}
private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
private def uniqueCommand(cmd: String): Option[LoopCommand] = {
@@ -193,31 +221,16 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
case xs => xs find (_.name == cmd)
}
}
-
- /** Print a welcome message */
- def printWelcome() {
- echo("""Welcome to
- ____ __
- / __/__ ___ _____/ /__
- _\ \/ _ \/ _ `/ __/ '_/
- /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT
- /_/
-""")
- import Properties._
- val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
- versionString, javaVmName, javaVersion)
- echo(welcomeMsg)
- }
-
+
/** Show the history */
lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
override def usage = "[num]"
def defaultLines = 20
-
+
def apply(line: String): Result = {
if (history eq NoHistory)
return "No history available."
-
+
val xs = words(line)
val current = history.index
val count = try xs.head.toInt catch { case _: Exception => defaultLines }
@@ -229,32 +242,38 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
}
}
- private def echo(msg: String) = {
+ // When you know you are most likely breaking into the middle
+ // of a line being typed. This softens the blow.
+ protected def echoAndRefresh(msg: String) = {
+ echo("\n" + msg)
+ in.redrawLine()
+ }
+ protected def echo(msg: String) = {
out println msg
out.flush()
}
- private def echoNoNL(msg: String) = {
+ protected def echoNoNL(msg: String) = {
out print msg
out.flush()
}
-
+
/** Search the history */
def searchHistory(_cmdline: String) {
val cmdline = _cmdline.toLowerCase
val offset = history.index - history.size + 1
-
+
for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
echo("%d %s".format(index + offset, line))
}
-
+
private var currentPrompt = Properties.shellPromptString
def setPrompt(prompt: String) = currentPrompt = prompt
/** Prompt to print when awaiting input */
def prompt = currentPrompt
-
+
import LoopCommand.{ cmd, nullary }
- /** Standard commands **/
+ /** Standard commands */
lazy val standardCommands = List(
cmd("cp", "<path>", "add a jar or directory to the classpath", addClasspath),
cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
@@ -263,53 +282,30 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand),
cmd("javap", "<path|class>", "disassemble a file or class name", javapCommand),
- nullary("keybindings", "show how ctrl-[A-Z] and other keys are bound", keybindingsCommand),
cmd("load", "<path>", "load and interpret a Scala file", loadCommand),
nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand),
- //nullary("power", "enable power user mode", powerCmd),
- nullary("quit", "exit the interpreter", () => Result(false, None)),
+// nullary("power", "enable power user mode", powerCmd),
+ nullary("quit", "exit the repl", () => Result(false, None)),
nullary("replay", "reset execution and replay all previous commands", replay),
+ nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
shCommand,
nullary("silent", "disable/enable automatic printing of results", verbosity),
- cmd("type", "<expr>", "display the type of an expression without evaluating it", typeCommand)
+ cmd("type", "[-v] <expr>", "display the type of an expression without evaluating it", typeCommand),
+ nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
)
-
+
/** Power user commands */
lazy val powerCommands: List[LoopCommand] = List(
- //nullary("dump", "displays a view of the interpreter's internal state", dumpCommand),
- //cmd("phase", "<phase>", "set the implicit phase for power commands", phaseCommand),
- cmd("wrap", "<method>", "name of method to wrap around each repl line", wrapCommand) withLongHelp ("""
- |:wrap
- |:wrap clear
- |:wrap <method>
- |
- |Installs a wrapper around each line entered into the repl.
- |Currently it must be the simple name of an existing method
- |with the specific signature shown in the following example.
- |
- |def timed[T](body: => T): T = {
- | val start = System.nanoTime
- | try body
- | finally println((System.nanoTime - start) + " nanos elapsed.")
- |}
- |:wrap timed
- |
- |If given no argument, :wrap names the wrapper installed.
- |An argument of clear will remove the wrapper if any is active.
- |Note that wrappers do not compose (a new one replaces the old
- |one) and also that the :phase command uses the same machinery,
- |so setting :wrap will clear any :phase setting.
- """.stripMargin.trim)
+ // cmd("phase", "<phase>", "set the implicit phase for power commands", phaseCommand)
)
-
- /*
- private def dumpCommand(): Result = {
- echo("" + power)
- history.asStrings takeRight 30 foreach echo
- in.redrawLine()
- }
- */
-
+
+ // private def dumpCommand(): Result = {
+ // echo("" + power)
+ // history.asStrings takeRight 30 foreach echo
+ // in.redrawLine()
+ // }
+ // private def valsCommand(): Result = power.valsDescription
+
private val typeTransforms = List(
"scala.collection.immutable." -> "immutable.",
"scala.collection.mutable." -> "mutable.",
@@ -317,7 +313,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
"java.lang." -> "jl.",
"scala.runtime." -> "runtime."
)
-
+
private def importsCommand(line: String): Result = {
val tokens = words(line)
val handlers = intp.languageWildcardHandlers ++ intp.importHandlers
@@ -333,7 +329,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit"
val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
-
+
intp.reporter.printMessage("%2d) %-30s %s%s".format(
idx + 1,
handler.importString,
@@ -342,12 +338,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
))
}
}
-
- private def implicitsCommand(line: String): Result = {
- val intp = SparkILoop.this.intp
+
+ private def implicitsCommand(line: String): Result = onIntp { intp =>
import intp._
- import global.Symbol
-
+ import global._
+
def p(x: Any) = intp.reporter.printMessage("" + x)
// If an argument is given, only show a source with that
@@ -360,17 +355,17 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
else (args exists (source.name.toString contains _))
}
}
-
+
if (filtered.isEmpty)
return "No implicits have been imported other than those in Predef."
-
+
filtered foreach {
case (source, syms) =>
p("/* " + syms.size + " implicit members imported from " + source.fullName + " */")
-
+
// This groups the members by where the symbol is defined
val byOwner = syms groupBy (_.owner)
- val sortedOwners = byOwner.toList sortBy { case (owner, _) => intp.afterTyper(source.info.baseClasses indexOf owner) }
+ val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) }
sortedOwners foreach {
case (owner, members) =>
@@ -388,10 +383,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
xss map (xs => xs sortBy (_.name.toString))
}
-
- val ownerMessage = if (owner == source) " defined in " else " inherited from "
+
+ val ownerMessage = if (owner == source) " defined in " else " inherited from "
p(" /* " + members.size + ownerMessage + owner.fullName + " */")
-
+
memberGroups foreach { group =>
group foreach (s => p(" " + intp.symbolDefString(s)))
p("")
@@ -400,158 +395,182 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
p("")
}
}
-
- protected def newJavap() = new Javap(intp.classLoader, new SparkIMain.ReplStrippingWriter(intp)) {
+
+ private def findToolsJar() = {
+ val jdkPath = Directory(jdkHome)
+ val jar = jdkPath / "lib" / "tools.jar" toFile;
+
+ if (jar isFile)
+ Some(jar)
+ else if (jdkPath.isDirectory)
+ jdkPath.deepFiles find (_.name == "tools.jar")
+ else None
+ }
+ private def addToolsJarToLoader() = {
+ val cl = findToolsJar match {
+ case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
+ case _ => intp.classLoader
+ }
+ if (Javap.isAvailable(cl)) {
+ logDebug(":javap available.")
+ cl
+ }
+ else {
+ logDebug(":javap unavailable: no tools.jar at " + jdkHome)
+ intp.classLoader
+ }
+ }
+
+ protected def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) {
override def tryClass(path: String): Array[Byte] = {
- // Look for Foo first, then Foo$, but if Foo$ is given explicitly,
- // we have to drop the $ to find object Foo, then tack it back onto
- // the end of the flattened name.
- def className = intp flatName path
- def moduleName = (intp flatName path.stripSuffix("$")) + "$"
+ val hd :: rest = path split '.' toList;
+ // If there are dots in the name, the first segment is the
+ // key to finding it.
+ if (rest.nonEmpty) {
+ intp optFlatName hd match {
+ case Some(flat) =>
+ val clazz = flat :: rest mkString NAME_JOIN_STRING
+ val bytes = super.tryClass(clazz)
+ if (bytes.nonEmpty) bytes
+ else super.tryClass(clazz + MODULE_SUFFIX_STRING)
+ case _ => super.tryClass(path)
+ }
+ }
+ else {
+ // Look for Foo first, then Foo$, but if Foo$ is given explicitly,
+ // we have to drop the $ to find object Foo, then tack it back onto
+ // the end of the flattened name.
+ def className = intp flatName path
+ def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING
- val bytes = super.tryClass(className)
- if (bytes.nonEmpty) bytes
- else super.tryClass(moduleName)
+ val bytes = super.tryClass(className)
+ if (bytes.nonEmpty) bytes
+ else super.tryClass(moduleName)
+ }
}
}
+ // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
private lazy val javap =
try newJavap()
catch { case _: Exception => null }
-
- private def typeCommand(line: String): Result = {
- intp.typeOfExpression(line) match {
- case Some(tp) => tp.toString
- case _ => "Failed to determine type."
+
+ // Still todo: modules.
+ private def typeCommand(line0: String): Result = {
+ line0.trim match {
+ case "" => ":type [-v] <expression>"
+ case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true)
+ case s => typeCommandInternal(s, false)
}
}
-
+
+ private def warningsCommand(): Result = {
+ if (intp.lastWarnings.isEmpty)
+ "Can't find any cached warnings."
+ else
+ intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
+ }
+
private def javapCommand(line: String): Result = {
if (javap == null)
- return ":javap unavailable on this platform."
- if (line == "")
- return ":javap [-lcsvp] [path1 path2 ...]"
-
- javap(words(line)) foreach { res =>
- if (res.isError) return "Failed: " + res.value
- else res.show()
- }
- }
- private def keybindingsCommand(): Result = {
- if (in.keyBindings.isEmpty) "Key bindings unavailable."
- else {
- echo("Reading jline properties for default key bindings.")
- echo("Accuracy not guaranteed: treat this as a guideline only.\n")
- in.keyBindings foreach (x => echo ("" + x))
- }
+ ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome)
+ else if (javaVersion startsWith "1.7")
+ ":javap not yet working with java 1.7"
+ else if (line == "")
+ ":javap [-lcsvp] [path1 path2 ...]"
+ else
+ javap(words(line)) foreach { res =>
+ if (res.isError) return "Failed: " + res.value
+ else res.show()
+ }
}
+
private def wrapCommand(line: String): Result = {
def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T"
- val intp = SparkILoop.this.intp
- val g: intp.global.type = intp.global
- import g._
-
- words(line) match {
- case Nil =>
- intp.executionWrapper match {
- case "" => "No execution wrapper is set."
- case s => "Current execution wrapper: " + s
- }
- case "clear" :: Nil =>
- intp.executionWrapper match {
- case "" => "No execution wrapper is set."
- case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper."
- }
- case wrapper :: Nil =>
- intp.typeOfExpression(wrapper) match {
- case Some(PolyType(List(targ), MethodType(List(arg), restpe))) =>
- intp setExecutionWrapper intp.pathToTerm(wrapper)
- "Set wrapper to '" + wrapper + "'"
- case Some(x) =>
- failMsg + "\nFound: " + x
- case _ =>
- failMsg + "\nFound: <unknown>"
- }
- case _ => failMsg
- }
- }
+ onIntp { intp =>
+ import intp._
+ import global._
- private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent"
- /*
- private def phaseCommand(name: String): Result = {
- // This line crashes us in TreeGen:
- //
- // if (intp.power.phased set name) "..."
- //
- // Exception in thread "main" java.lang.AssertionError: assertion failed: ._7.type
- // at scala.Predef$.assert(Predef.scala:99)
- // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:69)
- // at scala.tools.nsc.ast.TreeGen.mkAttributedQualifier(TreeGen.scala:44)
- // at scala.tools.nsc.ast.TreeGen.mkAttributedRef(TreeGen.scala:101)
- // at scala.tools.nsc.ast.TreeGen.mkAttributedStableRef(TreeGen.scala:143)
- //
- // But it works like so, type annotated.
- val phased: Phased = power.phased
- import phased.NoPhaseName
-
- if (name == "clear") {
- phased.set(NoPhaseName)
- intp.clearExecutionWrapper()
- "Cleared active phase."
- }
- else if (name == "") phased.get match {
- case NoPhaseName => "Usage: :phase <expr> (e.g. typer, erasure.next, erasure+3)"
- case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get)
- }
- else {
- val what = phased.parse(name)
- if (what.isEmpty || !phased.set(what))
- "'" + name + "' does not appear to represent a valid phase."
- else {
- intp.setExecutionWrapper(pathToPhaseWrapper)
- val activeMessage =
- if (what.toString.length == name.length) "" + what
- else "%s (%s)".format(what, name)
-
- "Active phase is now: " + activeMessage
+ words(line) match {
+ case Nil =>
+ intp.executionWrapper match {
+ case "" => "No execution wrapper is set."
+ case s => "Current execution wrapper: " + s
+ }
+ case "clear" :: Nil =>
+ intp.executionWrapper match {
+ case "" => "No execution wrapper is set."
+ case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper."
+ }
+ case wrapper :: Nil =>
+ intp.typeOfExpression(wrapper) match {
+ case PolyType(List(targ), MethodType(List(arg), restpe)) =>
+ intp setExecutionWrapper intp.pathToTerm(wrapper)
+ "Set wrapper to '" + wrapper + "'"
+ case tp =>
+ failMsg + "\nFound: <unknown>"
+ }
+ case _ => failMsg
}
}
}
- */
-
+
+ private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent"
+ // private def phaseCommand(name: String): Result = {
+ // val phased: Phased = power.phased
+ // import phased.NoPhaseName
+
+ // if (name == "clear") {
+ // phased.set(NoPhaseName)
+ // intp.clearExecutionWrapper()
+ // "Cleared active phase."
+ // }
+ // else if (name == "") phased.get match {
+ // case NoPhaseName => "Usage: :phase <expr> (e.g. typer, erasure.next, erasure+3)"
+ // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get)
+ // }
+ // else {
+ // val what = phased.parse(name)
+ // if (what.isEmpty || !phased.set(what))
+ // "'" + name + "' does not appear to represent a valid phase."
+ // else {
+ // intp.setExecutionWrapper(pathToPhaseWrapper)
+ // val activeMessage =
+ // if (what.toString.length == name.length) "" + what
+ // else "%s (%s)".format(what, name)
+
+ // "Active phase is now: " + activeMessage
+ // }
+ // }
+ // }
+
/** Available commands */
- def commands: List[LoopCommand] = standardCommands /* ++ (
+ def commands: List[LoopCommand] = standardCommands /*++ (
if (isReplPower) powerCommands else Nil
)*/
-
+
val replayQuestionMessage =
- """|The repl compiler has crashed spectacularly. Shall I replay your
- |session? I can re-run all lines except the last one.
+ """|That entry seems to have slain the compiler. Shall I replay
+ |your session? I can re-run each line except the last one.
|[y/n]
""".trim.stripMargin
- private val crashRecovery: PartialFunction[Throwable, Unit] = {
+ private val crashRecovery: PartialFunction[Throwable, Boolean] = {
case ex: Throwable =>
- if (settings.YrichExes.value) {
- val sources = implicitly[Sources]
- echo("\n" + ex.getMessage)
- echo(
- if (isReplDebug) "[searching " + sources.path + " for exception contexts...]"
- else "[searching for exception contexts...]"
- )
- echo(Exceptional(ex).force().context())
- }
- else {
- echo(util.stackTraceString(ex))
- }
+ echo(intp.global.throwableAsString(ex))
+
ex match {
case _: NoSuchMethodError | _: NoClassDefFoundError =>
- echo("Unrecoverable error.")
+ echo("\nUnrecoverable error.")
throw ex
case _ =>
- def fn(): Boolean = in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
+ def fn(): Boolean =
+ try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
+ catch { case _: RuntimeException => false }
+
if (fn()) replay()
else echo("\nAbandoning crashed session.")
}
+ true
}
/** The main read-eval-print loop for the repl. It calls
@@ -564,66 +583,89 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
in readLine prompt
}
// return false if repl should exit
- def processLine(line: String): Boolean =
+ def processLine(line: String): Boolean = {
+ if (isAsync) {
+ if (!awaitInitialized()) return false
+ runThunks()
+ }
if (line eq null) false // assume null means EOF
else command(line) match {
case Result(false, _) => false
case Result(_, Some(finalLine)) => addReplay(finalLine) ; true
case _ => true
}
-
- while (true) {
- try if (!processLine(readOneLine)) return
- catch crashRecovery
}
+ def innerLoop() {
+ if ( try processLine(readOneLine()) catch crashRecovery )
+ innerLoop()
+ }
+ innerLoop()
}
/** interpret all lines from a specified file */
- def interpretAllFrom(file: File) {
- val oldIn = in
- val oldReplay = replayCommandStack
-
- try file applyReader { reader =>
- in = SimpleReader(reader, out, false)
- echo("Loading " + file + "...")
- loop()
- }
- finally {
- in = oldIn
- replayCommandStack = oldReplay
+ def interpretAllFrom(file: File) {
+ savingReader {
+ savingReplayStack {
+ file applyReader { reader =>
+ in = SimpleReader(reader, out, false)
+ echo("Loading " + file + "...")
+ loop()
+ }
+ }
}
}
- /** create a new interpreter and replay all commands so far */
+ /** create a new interpreter and replay the given commands */
def replay() {
- closeInterpreter()
- createInterpreter()
- for (cmd <- replayCommands) {
+ reset()
+ if (replayCommandStack.isEmpty)
+ echo("Nothing to replay.")
+ else for (cmd <- replayCommands) {
echo("Replaying: " + cmd) // flush because maybe cmd will have its own output
command(cmd)
echo("")
}
}
-
+ def resetCommand() {
+ echo("Resetting repl state.")
+ if (replayCommandStack.nonEmpty) {
+ echo("Forgetting this session history:\n")
+ replayCommands foreach echo
+ echo("")
+ replayCommandStack = Nil
+ }
+ if (intp.namedDefinedTerms.nonEmpty)
+ echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
+ if (intp.definedTypes.nonEmpty)
+ echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
+
+ reset()
+ }
+
+ def reset() {
+ intp.reset()
+ // unleashAndSetPhase()
+ }
+
/** fork a shell and run a command */
lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
override def usage = "<command line>"
def apply(line: String): Result = line match {
case "" => showUsage()
- case _ =>
+ case _ =>
val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")"
intp interpret toRun
()
}
}
-
+
def withFile(filename: String)(action: File => Unit) {
val f = File(filename)
-
+
if (f.exists) action(f)
else echo("That file does not exist")
}
-
+
def loadCommand(arg: String) = {
var shouldReplay: Option[String] = None
withFile(arg)(f => {
@@ -633,6 +675,20 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
Result(true, shouldReplay)
}
+ def addAllClasspath(args: Seq[String]): Unit = {
+ var added = false
+ var totalClasspath = ""
+ for (arg <- args) {
+ val f = File(arg).normalize
+ if (f.exists) {
+ added = true
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ }
+ }
+ if (added) replay()
+ }
+
def addClasspath(arg: String): Unit = {
val f = File(arg).normalize
if (f.exists) {
@@ -643,23 +699,36 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
}
else echo("The path '" + f + "' doesn't seem to exist.")
}
-
+
def powerCmd(): Result = {
if (isReplPower) "Already in power mode."
- else enablePowerMode()
+ else enablePowerMode(false)
}
- def enablePowerMode() = {
- //replProps.power setValue true
- //power.unleash()
- //echo(power.banner)
+
+ def enablePowerMode(isDuringInit: Boolean) = {
+ // replProps.power setValue true
+ // unleashAndSetPhase()
+ // asyncEcho(isDuringInit, power.banner)
}
-
+ // private def unleashAndSetPhase() {
+// if (isReplPower) {
+// // power.unleash()
+// // Set the phase to "typer"
+// intp beSilentDuring phaseCommand("typer")
+// }
+// }
+
+ def asyncEcho(async: Boolean, msg: => String) {
+ if (async) asyncMessage(msg)
+ else echo(msg)
+ }
+
def verbosity() = {
- val old = intp.printResults
- intp.printResults = !old
- echo("Switched " + (if (old) "off" else "on") + " result printing.")
+ // val old = intp.printResults
+ // intp.printResults = !old
+ // echo("Switched " + (if (old) "off" else "on") + " result printing.")
}
-
+
/** Run one command submitted by the user. Two values are returned:
* (1) whether to keep running, (2) the line to record for replay,
* if any. */
@@ -674,11 +743,11 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
else if (intp.global == null) Result(false, None) // Notice failure to create compiler
else Result(true, interpretStartingWith(line))
}
-
+
private def readWhile(cond: String => Boolean) = {
Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
}
-
+
def pasteCommand(): Result = {
echo("// Entering paste mode (ctrl-D to finish)\n")
val code = readWhile(_ => true) mkString "\n"
@@ -686,23 +755,19 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
intp interpret code
()
}
-
+
private object paste extends Pasted {
val ContinueString = " | "
val PromptString = "scala> "
-
+
def interpret(line: String): Unit = {
echo(line.trim)
intp interpret line
echo("")
}
-
+
def transcript(start: String) = {
- // Printing this message doesn't work very well because it's buried in the
- // transcript they just pasted. Todo: a short timer goes off when
- // lines stop coming which tells them to hit ctrl-D.
- //
- // echo("// Detected repl transcript paste: ctrl-D to finish.")
+ echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
}
}
@@ -717,7 +782,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
def interpretStartingWith(code: String): Option[String] = {
// signal completion non-completion input has been received
in.completion.resetVerbosity()
-
+
def reallyInterpret = {
val reallyResult = intp.interpret(code)
(reallyResult, reallyResult match {
@@ -727,7 +792,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
if (in.interactive && code.endsWith("\n\n")) {
echo("You typed two blank lines. Starting a new command.")
None
- }
+ }
else in.readLine(ContinueString) match {
case null =>
// we know compilation is going to fail since we're at EOF and the
@@ -741,10 +806,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
}
})
}
-
+
/** Here we place ourselves between the user and the interpreter and examine
* the input they are ostensibly submitting. We intervene in several cases:
- *
+ *
* 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
* 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
* on the previous result.
@@ -759,28 +824,12 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
interpretStartingWith(intp.mostRecentVar + code)
}
- else {
- def runCompletion = in.completion execute code map (intp bindValue _)
- /** Due to my accidentally letting file completion execution sneak ahead
- * of actual parsing this now operates in such a way that the scala
- * interpretation always wins. However to avoid losing useful file
- * completion I let it fail and then check the others. So if you
- * type /tmp it will echo a failure and then give you a Directory object.
- * It's not pretty: maybe I'll implement the silence bits I need to avoid
- * echoing the failure.
- */
- if (intp isParseable code) {
- val (code, result) = reallyInterpret
- //if (power != null && code == IR.Error)
- // runCompletion
-
- result
- }
- else runCompletion match {
- case Some(_) => None // completion hit: avoid the latent error
- case _ => reallyInterpret._2 // trigger the latent error
- }
+ else if (code.trim startsWith "//") {
+ // line comment, do nothing
+ None
}
+ else
+ reallyInterpret._2
}
// runs :load `file` on any files passed via -i
@@ -794,7 +843,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
}
case _ =>
}
-
+
/** Tries to create a JLineReader, falling back to SimpleReader:
* unless settings or properties are such that it should start
* with SimpleReader.
@@ -802,7 +851,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
def chooseReader(settings: Settings): InteractiveReader = {
if (settings.Xnojline.value || Properties.isEmacsShell)
SimpleReader()
- else try SparkJLineReader(
+ else try new SparkJLineReader(
if (settings.noCompletion.value) NoCompletion
else new SparkJLineCompletion(intp)
)
@@ -813,22 +862,71 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
}
}
- def initializeSpark() {
- intp.beQuietDuring {
- command("""
- org.apache.spark.repl.Main.interp.out.println("Creating SparkContext...");
- org.apache.spark.repl.Main.interp.out.flush();
- @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext();
- org.apache.spark.repl.Main.interp.out.println("Spark context available as sc.");
- org.apache.spark.repl.Main.interp.out.flush();
- """)
- command("import org.apache.spark.SparkContext._")
+ val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+ val m = u.runtimeMirror(getClass.getClassLoader)
+ private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
+ u.TypeTag[T](
+ m,
+ new TypeCreator {
+ def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type =
+ m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
+ })
+
+ def process(settings: Settings): Boolean = savingContextLoader {
+ this.settings = settings
+ createInterpreter()
+
+ // sets in to some kind of reader depending on environmental cues
+ in = in0 match {
+ case Some(reader) => SimpleReader(reader, out, true)
+ case None =>
+ // some post-initialization
+ chooseReader(settings) match {
+ case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x
+ case x => x
+ }
}
- echo("Type in expressions to have them evaluated.")
- echo("Type :help for more information.")
- }
+ lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain]
+ // Bind intp somewhere out of the regular namespace where
+ // we can get at it in generated code.
+ addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain])))
+ addThunk({
+ import scala.tools.nsc.io._
+ import Properties.userHome
+ import scala.compat.Platform.EOL
+ val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp())
+ if (autorun.isDefined) intp.quietRun(autorun.get)
+ })
+
+ addThunk(printWelcome())
+ addThunk(initializeSpark())
+
+ loadFiles(settings)
+ // it is broken on startup; go ahead and exit
+ if (intp.reporter.hasErrors)
+ return false
- var sparkContext: SparkContext = null
+ // This is about the illusion of snappiness. We call initialize()
+ // which spins off a separate thread, then print the prompt and try
+ // our best to look ready. The interlocking lazy vals tend to
+ // inter-deadlock, so we break the cycle with a single asynchronous
+ // message to an actor.
+ if (isAsync) {
+ intp initialize initializedCallback()
+ createAsyncListener() // listens for signal to run postInitialization
+ }
+ else {
+ intp.initializeSynchronous()
+ postInitialization()
+ }
+ // printWelcome()
+
+ try loop()
+ catch AbstractOrMissingHandler()
+ finally closeInterpreter()
+
+ true
+ }
def createSparkContext(): SparkContext = {
val uri = System.getenv("SPARK_EXECUTOR_URI")
@@ -842,78 +940,26 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
if (prop != null) prop else "local"
}
}
- val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
- .getOrElse(new Array[String](0))
- .map(new java.io.File(_).getAbsolutePath)
+ val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath)
sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ echo("Created spark context..")
sparkContext
}
- def process(settings: Settings): Boolean = {
- // Ensure logging is initialized before any Spark threads try to use logs
- // (because SLF4J initialization is not thread safe)
- initLogging()
-
- printWelcome()
- echo("Initializing interpreter...")
-
- // Add JARS specified in Spark's ADD_JARS variable to classpath
- val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0))
- jars.foreach(settings.classpath.append(_))
-
- this.settings = settings
- createInterpreter()
-
- // sets in to some kind of reader depending on environmental cues
- in = in0 match {
- case Some(reader) => SimpleReader(reader, out, true)
- case None => chooseReader(settings)
- }
-
- loadFiles(settings)
- // it is broken on startup; go ahead and exit
- if (intp.reporter.hasErrors)
- return false
-
- try {
- // this is about the illusion of snappiness. We call initialize()
- // which spins off a separate thread, then print the prompt and try
- // our best to look ready. Ideally the user will spend a
- // couple seconds saying "wow, it starts so fast!" and by the time
- // they type a command the compiler is ready to roll.
- intp.initialize()
- initializeSpark()
- if (isReplPower) {
- echo("Starting in power mode, one moment...\n")
- enablePowerMode()
- }
- loop()
- }
- finally closeInterpreter()
- true
- }
-
/** process command-line arguments and do as they request */
def process(args: Array[String]): Boolean = {
- val command = new CommandLine(args.toList, msg => echo("scala: " + msg))
+ val command = new CommandLine(args.toList, echo)
def neededHelp(): String =
(if (command.settings.help.value) command.usageMsg + "\n" else "") +
(if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
-
+
// if they asked for no help and command is valid, we call the real main
neededHelp() match {
case "" => command.ok && process(command.settings)
case help => echoNoNL(help) ; true
}
}
-
- @deprecated("Use `process` instead", "2.9.0")
- def main(args: Array[String]): Unit = {
- if (isReplDebug)
- System.out.println(new java.util.Date)
-
- process(args)
- }
+
@deprecated("Use `process` instead", "2.9.0")
def main(settings: Settings): Unit = process(settings)
}
@@ -922,15 +968,17 @@ object SparkILoop {
implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
private def echo(msg: String) = Console println msg
+ def getAddedJars: Array[String] = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0))
+
// Designed primarily for use by test code: take a String with a
// bunch of code, and prints out a transcript of what it would look
// like if you'd just typed it into the repl.
def runForTranscript(code: String, settings: Settings): String = {
import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-
+
stringFromStream { ostream =>
Console.withOut(ostream) {
- val output = new PrintWriter(new OutputStreamWriter(ostream), true) {
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
override def write(str: String) = {
// completely skip continuation lines
if (str forall (ch => ch.isWhitespace || ch == '|')) ()
@@ -949,26 +997,29 @@ object SparkILoop {
}
}
val repl = new SparkILoop(input, output)
+
if (settings.classpath.isDefault)
settings.classpath.value = sys.props("java.class.path")
+ getAddedJars.foreach(settings.classpath.append(_))
+
repl process settings
}
}
}
-
+
/** Creates an interpreter loop with default settings and feeds
* the given code to it as input.
*/
def run(code: String, sets: Settings = new Settings): String = {
import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
-
+
stringFromStream { ostream =>
Console.withOut(ostream) {
val input = new BufferedReader(new StringReader(code))
- val output = new PrintWriter(new OutputStreamWriter(ostream), true)
- val repl = new SparkILoop(input, output)
-
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
+ val repl = new ILoop(input, output)
+
if (sets.classpath.isDefault)
sets.classpath.value = sys.props("java.class.path")
@@ -977,32 +1028,4 @@ object SparkILoop {
}
}
def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
-
- // provide the enclosing type T
- // in order to set up the interpreter's classpath and parent class loader properly
- def breakIf[T: Manifest](assertion: => Boolean, args: NamedParam*): Unit =
- if (assertion) break[T](args.toList)
-
- // start a repl, binding supplied args
- def break[T: Manifest](args: List[NamedParam]): Unit = {
- val msg = if (args.isEmpty) "" else " Binding " + args.size + " value%s.".format(
- if (args.size == 1) "" else "s"
- )
- echo("Debug repl starting." + msg)
- val repl = new SparkILoop {
- override def prompt = "\ndebug> "
- }
- repl.settings = new Settings(echo)
- repl.settings.embeddedDefaults[T]
- repl.createInterpreter()
- repl.in = SparkJLineReader(repl)
-
- // rebind exit so people don't accidentally call sys.exit by way of predef
- repl.quietRun("""def exit = println("Type :quit to resume program execution.")""")
- args foreach (p => repl.bind(p.name, p.tpe, p.value))
- repl.loop()
-
- echo("\nDebug repl exiting.")
- repl.closeInterpreter()
- }
}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
new file mode 100644
index 0000000000..21b1ba305d
--- /dev/null
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -0,0 +1,143 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter._
+
+import scala.reflect.internal.util.Position
+import scala.util.control.Exception.ignoring
+import scala.tools.nsc.util.stackTraceString
+
+/**
+ * Machinery for the asynchronous initialization of the repl.
+ */
+trait SparkILoopInit {
+ self: SparkILoop =>
+
+ /** Print a welcome message */
+ def printWelcome() {
+ echo("""Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /___/ .__/\_,_/_/ /_/\_\ version 0.9.0-SNAPSHOT
+ /_/
+""")
+ import Properties._
+ val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
+ versionString, javaVmName, javaVersion)
+ echo(welcomeMsg)
+ echo("Type in expressions to have them evaluated.")
+ echo("Type :help for more information.")
+ }
+
+ protected def asyncMessage(msg: String) {
+ if (isReplInfo || isReplPower)
+ echoAndRefresh(msg)
+ }
+
+ private val initLock = new java.util.concurrent.locks.ReentrantLock()
+ private val initCompilerCondition = initLock.newCondition() // signal the compiler is initialized
+ private val initLoopCondition = initLock.newCondition() // signal the whole repl is initialized
+ private val initStart = System.nanoTime
+
+ private def withLock[T](body: => T): T = {
+ initLock.lock()
+ try body
+ finally initLock.unlock()
+ }
+ // a condition used to ensure serial access to the compiler.
+ @volatile private var initIsComplete = false
+ @volatile private var initError: String = null
+ private def elapsed() = "%.3f".format((System.nanoTime - initStart).toDouble / 1000000000L)
+
+ // the method to be called when the interpreter is initialized.
+ // Very important this method does nothing synchronous (i.e. do
+ // not try to use the interpreter) because until it returns, the
+ // repl's lazy val `global` is still locked.
+ protected def initializedCallback() = withLock(initCompilerCondition.signal())
+
+ // Spins off a thread which awaits a single message once the interpreter
+ // has been initialized.
+ protected def createAsyncListener() = {
+ io.spawn {
+ withLock(initCompilerCondition.await())
+ asyncMessage("[info] compiler init time: " + elapsed() + " s.")
+ postInitialization()
+ }
+ }
+
+ // called from main repl loop
+ protected def awaitInitialized(): Boolean = {
+ if (!initIsComplete)
+ withLock { while (!initIsComplete) initLoopCondition.await() }
+ if (initError != null) {
+ println("""
+ |Failed to initialize the REPL due to an unexpected error.
+ |This is a bug, please, report it along with the error diagnostics printed below.
+ |%s.""".stripMargin.format(initError)
+ )
+ false
+ } else true
+ }
+ // private def warningsThunks = List(
+ // () => intp.bind("lastWarnings", "" + typeTag[List[(Position, String)]], intp.lastWarnings _),
+ // )
+
+ protected def postInitThunks = List[Option[() => Unit]](
+ Some(intp.setContextClassLoader _),
+ if (isReplPower) Some(() => enablePowerMode(true)) else None
+ ).flatten
+ // ++ (
+ // warningsThunks
+ // )
+ // called once after init condition is signalled
+ protected def postInitialization() {
+ try {
+ postInitThunks foreach (f => addThunk(f()))
+ runThunks()
+ } catch {
+ case ex: Throwable =>
+ initError = stackTraceString(ex)
+ throw ex
+ } finally {
+ initIsComplete = true
+
+ if (isAsync) {
+ asyncMessage("[info] total init time: " + elapsed() + " s.")
+ withLock(initLoopCondition.signal())
+ }
+ }
+ }
+
+ def initializeSpark() {
+ intp.beQuietDuring {
+ command("""
+ @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext();
+ """)
+ command("import org.apache.spark.SparkContext._")
+ }
+ echo("Spark context available as sc.")
+ }
+
+ // code to be executed only after the interpreter is initialized
+ // and the lazy val `global` can be accessed without risk of deadlock.
+ private var pendingThunks: List[() => Unit] = Nil
+ protected def addThunk(body: => Unit) = synchronized {
+ pendingThunks :+= (() => body)
+ }
+ protected def runThunks(): Unit = synchronized {
+ if (pendingThunks.nonEmpty)
+ logDebug("Clearing " + pendingThunks.size + " thunks.")
+
+ while (pendingThunks.nonEmpty) {
+ val thunk = pendingThunks.head
+ pendingThunks = pendingThunks.tail
+ thunk()
+ }
+ }
+}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index e6e35c9b5d..e1455ef8a1 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -1,5 +1,5 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Martin Odersky
*/
@@ -9,304 +9,333 @@ import scala.tools.nsc._
import scala.tools.nsc.interpreter._
import Predef.{ println => _, _ }
-import java.io.{ PrintWriter }
-import java.lang.reflect
+import util.stringFromWriter
+import scala.reflect.internal.util._
import java.net.URL
-import util.{ Set => _, _ }
-import io.{ AbstractFile, PlainFile, VirtualDirectory }
-import reporters.{ ConsoleReporter, Reporter }
-import symtab.{ Flags, Names }
-import scala.tools.nsc.interpreter.{ Results => IR }
+import scala.sys.BooleanProp
+import io.{AbstractFile, PlainFile, VirtualDirectory}
+
+import reporters._
+import symtab.Flags
+import scala.reflect.internal.Names
import scala.tools.util.PathResolver
-import scala.tools.nsc.util.{ ScalaClassLoader, Exceptional }
+import scala.tools.nsc.util.ScalaClassLoader
import ScalaClassLoader.URLClassLoader
-import Exceptional.unwrap
+import scala.tools.nsc.util.Exceptional.unwrap
import scala.collection.{ mutable, immutable }
-import scala.PartialFunction.{ cond, condOpt }
import scala.util.control.Exception.{ ultimately }
-import scala.reflect.NameTransformer
import SparkIMain._
+import java.util.concurrent.Future
+import typechecker.Analyzer
+import scala.language.implicitConversions
+import scala.reflect.runtime.{ universe => ru }
+import scala.reflect.{ ClassTag, classTag }
+import scala.tools.reflect.StdRuntimeTags._
+import scala.util.control.ControlThrowable
+import util.stackTraceString
import org.apache.spark.HttpServer
import org.apache.spark.util.Utils
import org.apache.spark.SparkEnv
+import org.apache.spark.Logging
+
+// /** directory to save .class files to */
+// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) {
+// private def pp(root: AbstractFile, indentLevel: Int) {
+// val spaces = " " * indentLevel
+// out.println(spaces + root.name)
+// if (root.isDirectory)
+// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1))
+// }
+// // print the contents hierarchically
+// def show() = pp(this, 0)
+// }
+
+ /** An interpreter for Scala code.
+ *
+ * The main public entry points are compile(), interpret(), and bind().
+ * The compile() method loads a complete Scala file. The interpret() method
+ * executes one line of Scala code at the request of the user. The bind()
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ *
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ *
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports members called "$eval" and "$print". To accomodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ *
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+ class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging {
+ imain =>
-/** An interpreter for Scala code.
- *
- * The main public entry points are compile(), interpret(), and bind().
- * The compile() method loads a complete Scala file. The interpret() method
- * executes one line of Scala code at the request of the user. The bind()
- * method binds an object to a variable that can then be used by later
- * interpreted code.
- *
- * The overall approach is based on compiling the requested code and then
- * using a Java classloader and Java reflection to run the code
- * and access its results.
- *
- * In more detail, a single compiler instance is used
- * to accumulate all successfully compiled or interpreted Scala code. To
- * "interpret" a line of code, the compiler generates a fresh object that
- * includes the line of code and which has public member(s) to export
- * all variables defined by that code. To extract the result of an
- * interpreted line to show the user, a second "result object" is created
- * which imports the variables exported by the above object and then
- * exports a single member named "$export". To accomodate user expressions
- * that read from variables or methods defined in previous statements, "import"
- * statements are used.
- *
- * This interpreter shares the strengths and weaknesses of using the
- * full compiler-to-Java. The main strength is that interpreted code
- * behaves exactly as does compiled code, including running at full speed.
- * The main weakness is that redefining classes and methods is not handled
- * properly, because rebinding at the Java level is technically difficult.
- *
- * @author Moez A. Abdel-Gawad
- * @author Lex Spoon
- */
-class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends SparkImports {
- imain =>
-
- /** construct an interpreter that reports to Console */
- def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
- def this() = this(new Settings())
-
- /** whether to print out result lines */
- var printResults: Boolean = true
-
- /** whether to print errors */
- var totalSilence: Boolean = false
-
- private val RESULT_OBJECT_PREFIX = "RequestResult$"
-
- lazy val formatting: Formatting = new Formatting {
- val prompt = Properties.shellPromptString
- }
- import formatting._
-
- val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
-
- /** Local directory to save .class files too */
- val outputDir = {
- val tmp = System.getProperty("java.io.tmpdir")
- val rootDir = System.getProperty("spark.repl.classdir", tmp)
- Utils.createTempDir(rootDir)
- }
- if (SPARK_DEBUG_REPL) {
- echo("Output directory: " + outputDir)
- }
+ val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
- /** Scala compiler virtual directory for outputDir */
- val virtualDirectory = new PlainFile(outputDir)
+ /** Local directory to save .class files too */
+ val outputDir = {
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = System.getProperty("spark.repl.classdir", tmp)
+ Utils.createTempDir(rootDir)
+ }
+ if (SPARK_DEBUG_REPL) {
+ echo("Output directory: " + outputDir)
+ }
- /** Jetty server that will serve our classes to worker nodes */
- val classServer = new HttpServer(outputDir)
+ val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
+ val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */
+ private var currentSettings: Settings = initialSettings
+ var printResults = true // whether to print result lines
+ var totalSilence = false // whether to print anything
+ private var _initializeComplete = false // compiler is initialized
+ private var _isInitialized: Future[Boolean] = null // set up initialization future
+ private var bindExceptions = true // whether to bind the lastException variable
+ private var _executionWrapper = "" // code to be wrapped around all lines
+
+
+ // Start the classServer and store its URI in a spark system property
+ // (which will be passed to executors so that they can connect to it)
+ classServer.start()
+ System.setProperty("spark.repl.class.uri", classServer.uri)
+ if (SPARK_DEBUG_REPL) {
+ echo("Class server started, URI = " + classServer.uri)
+ }
- // Start the classServer and store its URI in a spark system property
- // (which will be passed to executors so that they can connect to it)
- classServer.start()
- System.setProperty("spark.repl.class.uri", classServer.uri)
- if (SPARK_DEBUG_REPL) {
- echo("Class server started, URI = " + classServer.uri)
- }
+ /** We're going to go to some trouble to initialize the compiler asynchronously.
+ * It's critical that nothing call into it until it's been initialized or we will
+ * run into unrecoverable issues, but the perceived repl startup time goes
+ * through the roof if we wait for it. So we initialize it with a future and
+ * use a lazy val to ensure that any attempt to use the compiler object waits
+ * on the future.
+ */
+ private var _classLoader: AbstractFileClassLoader = null // active classloader
+ private val _compiler: Global = newCompiler(settings, reporter) // our private compiler
- /*
- // directory to save .class files to
- val virtualDirectory = new VirtualDirectory("(memory)", None) {
- private def pp(root: io.AbstractFile, indentLevel: Int) {
- val spaces = " " * indentLevel
- out.println(spaces + root.name)
- if (root.isDirectory)
- root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1))
+ private val nextReqId = {
+ var counter = 0
+ () => { counter += 1 ; counter }
}
- // print the contents hierarchically
- def show() = pp(this, 0)
- }
- */
-
- /** reporter */
- lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
- import reporter.{ printMessage, withoutTruncating }
-
- // not sure if we have some motivation to print directly to console
- private def echo(msg: String) { Console println msg }
-
- // protected def defaultImports: List[String] = List("_root_.scala.sys.exit")
-
- /** We're going to go to some trouble to initialize the compiler asynchronously.
- * It's critical that nothing call into it until it's been initialized or we will
- * run into unrecoverable issues, but the perceived repl startup time goes
- * through the roof if we wait for it. So we initialize it with a future and
- * use a lazy val to ensure that any attempt to use the compiler object waits
- * on the future.
- */
- private val _compiler: Global = newCompiler(settings, reporter)
- private var _initializeComplete = false
- def isInitializeComplete = _initializeComplete
-
- private def _initialize(): Boolean = {
- val source = """
- |class $repl_$init {
- | List(1) map (_ + 1)
- |}
- |""".stripMargin
-
- val result = try {
- new _compiler.Run() compileSources List(new BatchSourceFile("<init>", source))
- if (isReplDebug || settings.debug.value) {
- // Can't use printMessage here, it deadlocks
- Console.println("Repl compiler initialized.")
- }
- // addImports(defaultImports: _*)
- true
- }
- catch {
- case x: AbstractMethodError =>
- printMessage("""
- |Failed to initialize compiler: abstract method error.
- |This is most often remedied by a full clean and recompile.
- |""".stripMargin
- )
- x.printStackTrace()
- false
- case x: MissingRequirementError => printMessage("""
- |Failed to initialize compiler: %s not found.
- |** Note that as of 2.8 scala does not assume use of the java classpath.
- |** For the old behavior pass -usejavacp to scala, or if using a Settings
- |** object programatically, settings.usejavacp.value = true.""".stripMargin.format(x.req)
+
+ def compilerClasspath: Seq[URL] = (
+ if (isInitializeComplete) global.classPath.asURLs
+ else new PathResolver(settings).result.asURLs // the compiler's classpath
)
- false
+ def settings = currentSettings
+ def mostRecentLine = prevRequestList match {
+ case Nil => ""
+ case req :: _ => req.originalLine
+ }
+ // Run the code body with the given boolean settings flipped to true.
+ def withoutWarnings[T](body: => T): T = beQuietDuring {
+ val saved = settings.nowarn.value
+ if (!saved)
+ settings.nowarn.value = true
+
+ try body
+ finally if (!saved) settings.nowarn.value = false
}
-
- try result
- finally _initializeComplete = result
- }
-
- // set up initialization future
- private var _isInitialized: () => Boolean = null
- def initialize() = synchronized {
- if (_isInitialized == null)
- _isInitialized = scala.concurrent.ops future _initialize()
- }
- /** the public, go through the future compiler */
- lazy val global: Global = {
- initialize()
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this() = this(new Settings())
- // blocks until it is ; false means catastrophic failure
- if (_isInitialized()) _compiler
- else null
- }
- @deprecated("Use `global` for access to the compiler instance.", "2.9.0")
- lazy val compiler: global.type = global
-
- import global._
-
- object naming extends {
- val global: imain.global.type = imain.global
- } with Naming {
- // make sure we don't overwrite their unwisely named res3 etc.
- override def freshUserVarName(): String = {
- val name = super.freshUserVarName()
- if (definedNameMap contains name) freshUserVarName()
- else name
+ lazy val repllog: Logger = new Logger {
+ val out: JPrintWriter = imain.out
+ val isInfo: Boolean = BooleanProp keyExists "scala.repl.info"
+ val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug"
+ val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace"
}
- }
- import naming._
-
- // object dossiers extends {
- // val intp: imain.type = imain
- // } with Dossiers { }
- // import dossiers._
-
- lazy val memberHandlers = new {
- val intp: imain.type = imain
- } with SparkMemberHandlers
- import memberHandlers._
-
- def atPickler[T](op: => T): T = atPhase(currentRun.picklerPhase)(op)
- def afterTyper[T](op: => T): T = atPhase(currentRun.typerPhase.next)(op)
-
- /** Temporarily be quiet */
- def beQuietDuring[T](operation: => T): T = {
- val wasPrinting = printResults
- ultimately(printResults = wasPrinting) {
- if (isReplDebug) echo(">> beQuietDuring")
- else printResults = false
-
- operation
+ lazy val formatting: Formatting = new Formatting {
+ val prompt = Properties.shellPromptString
}
- }
- def beSilentDuring[T](operation: => T): T = {
- val saved = totalSilence
- totalSilence = true
- try operation
- finally totalSilence = saved
- }
-
- def quietRun[T](code: String) = beQuietDuring(interpret(code))
-
- /** whether to bind the lastException variable */
- private var bindLastException = true
-
- /** A string representing code to be wrapped around all lines. */
- private var _executionWrapper: String = ""
- def executionWrapper = _executionWrapper
- def setExecutionWrapper(code: String) = _executionWrapper = code
- def clearExecutionWrapper() = _executionWrapper = ""
-
- /** Temporarily stop binding lastException */
- def withoutBindingLastException[T](operation: => T): T = {
- val wasBinding = bindLastException
- ultimately(bindLastException = wasBinding) {
- bindLastException = false
- operation
+ lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
+
+ import formatting._
+ import reporter.{ printMessage, withoutTruncating }
+
+ // This exists mostly because using the reporter too early leads to deadlock.
+ private def echo(msg: String) { Console println msg }
+ private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
+ private def _initialize() = {
+ try {
+ // todo. if this crashes, REPL will hang
+ new _compiler.Run() compileSources _initSources
+ _initializeComplete = true
+ true
+ }
+ catch AbstractOrMissingHandler()
+ }
+ private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
+
+ // argument is a thunk to execute after init is done
+ def initialize(postInitSignal: => Unit) {
+ synchronized {
+ if (_isInitialized == null) {
+ _isInitialized = io.spawn {
+ try _initialize()
+ finally postInitSignal
+ }
+ }
+ }
+ }
+ def initializeSynchronous(): Unit = {
+ if (!isInitializeComplete) {
+ _initialize()
+ assert(global != null, global)
+ }
+ }
+ def isInitializeComplete = _initializeComplete
+
+ /** the public, go through the future compiler */
+ lazy val global: Global = {
+ if (isInitializeComplete) _compiler
+ else {
+ // If init hasn't been called yet you're on your own.
+ if (_isInitialized == null) {
+ logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.")
+ initialize(())
+ }
+ // // blocks until it is ; false means catastrophic failure
+ if (_isInitialized.get()) _compiler
+ else null
+ }
+ }
+ @deprecated("Use `global` for access to the compiler instance.", "2.9.0")
+ lazy val compiler: global.type = global
+
+ import global._
+ import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember}
+ import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass}
+
+ implicit class ReplTypeOps(tp: Type) {
+ def orElse(other: => Type): Type = if (tp ne NoType) tp else other
+ def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
+ }
+
+ // TODO: If we try to make naming a lazy val, we run into big time
+ // scalac unhappiness with what look like cycles. It has not been easy to
+ // reduce, but name resolution clearly takes different paths.
+ object naming extends {
+ val global: imain.global.type = imain.global
+ } with Naming {
+ // make sure we don't overwrite their unwisely named res3 etc.
+ def freshUserTermName(): TermName = {
+ val name = newTermName(freshUserVarName())
+ if (definedNameMap contains name) freshUserTermName()
+ else name
+ }
+ def isUserTermName(name: Name) = isUserVarName("" + name)
+ def isInternalTermName(name: Name) = isInternalVarName("" + name)
+ }
+ import naming._
+
+ object deconstruct extends {
+ val global: imain.global.type = imain.global
+ } with StructuredTypeStrings
+
+ lazy val memberHandlers = new {
+ val intp: imain.type = imain
+ } with SparkMemberHandlers
+ import memberHandlers._
+
+ /** Temporarily be quiet */
+ def beQuietDuring[T](body: => T): T = {
+ val saved = printResults
+ printResults = false
+ try body
+ finally printResults = saved
+ }
+ def beSilentDuring[T](operation: => T): T = {
+ val saved = totalSilence
+ totalSilence = true
+ try operation
+ finally totalSilence = saved
+ }
+
+ def quietRun[T](code: String) = beQuietDuring(interpret(code))
+
+
+ private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = {
+ case t: ControlThrowable => throw t
+ case t: Throwable =>
+ logDebug(label + ": " + unwrap(t))
+ logDebug(stackTraceString(unwrap(t)))
+ alt
+ }
+ /** takes AnyRef because it may be binding a Throwable or an Exceptional */
+
+ private def withLastExceptionLock[T](body: => T, alt: => T): T = {
+ assert(bindExceptions, "withLastExceptionLock called incorrectly.")
+ bindExceptions = false
+
+ try beQuietDuring(body)
+ catch logAndDiscard("withLastExceptionLock", alt)
+ finally bindExceptions = true
}
- }
-
- protected def createLineManager(): Line.Manager = new Line.Manager
- lazy val lineManager = createLineManager()
-
- /** interpreter settings */
- lazy val isettings = new SparkISettings(this)
-
- /** Instantiate a compiler. Subclasses can override this to
- * change the compiler class used by this interpreter. */
- protected def newCompiler(settings: Settings, reporter: Reporter) = {
- settings.outputDirs setSingleOutput virtualDirectory
- settings.exposeEmptyPackage.value = true
- new Global(settings, reporter)
- }
-
- /** the compiler's classpath, as URL's */
- lazy val compilerClasspath: List[URL] = new PathResolver(settings) asURLs
- /* A single class loader is used for all commands interpreted by this Interpreter.
+ def executionWrapper = _executionWrapper
+ def setExecutionWrapper(code: String) = _executionWrapper = code
+ def clearExecutionWrapper() = _executionWrapper = ""
+
+ /** interpreter settings */
+ lazy val isettings = new SparkISettings(this)
+
+ /** Instantiate a compiler. Overridable. */
+ protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = {
+ settings.outputDirs setSingleOutput virtualDirectory
+ settings.exposeEmptyPackage.value = true
+ new Global(settings, reporter) with ReplGlobal {
+ override def toString: String = "<global>"
+ }
+ }
+
+ /** Parent classloader. Overridable. */
+ protected def parentClassLoader: ClassLoader =
+ SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
It would also be possible to create a new class loader for each command
to interpret. The advantages of the current approach are:
- - Expressions are only evaluated one time. This is especially
- significant for I/O, e.g. "val x = Console.readLine"
-
- The main disadvantage is:
-
- - Objects, classes, and methods cannot be rebound. Instead, definitions
- shadow the old ones, and old code objects refer to the old
- definitions.
- */
- private var _classLoader: AbstractFileClassLoader = null
- def resetClassLoader() = _classLoader = makeClassLoader()
- def classLoader: AbstractFileClassLoader = {
- if (_classLoader == null)
- resetClassLoader()
-
- _classLoader
- }
- private def makeClassLoader(): AbstractFileClassLoader = {
- val parent =
- if (parentClassLoader == null) ScalaClassLoader fromURLs compilerClasspath
- else new URLClassLoader(compilerClasspath, parentClassLoader)
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
- new AbstractFileClassLoader(virtualDirectory, parent) {
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
+ def resetClassLoader() = {
+ logDebug("Setting new classloader: was " + _classLoader)
+ _classLoader = null
+ ensureClassLoader()
+ }
+ final def ensureClassLoader() {
+ if (_classLoader == null)
+ _classLoader = makeClassLoader()
+ }
+ def classLoader: AbstractFileClassLoader = {
+ ensureClassLoader()
+ _classLoader
+ }
+ private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) {
/** Overridden here to try translating a simple name to the generated
* class name if the original attempt fails. This method is used by
* getResourceAsStream as well as findClass.
@@ -314,223 +343,300 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
override protected def findAbstractFile(name: String): AbstractFile = {
super.findAbstractFile(name) match {
// deadlocks on startup if we try to translate names too early
- case null if isInitializeComplete => generatedName(name) map (x => super.findAbstractFile(x)) orNull
- case file => file
+ case null if isInitializeComplete =>
+ generatedName(name) map (x => super.findAbstractFile(x)) orNull
+ case file =>
+ file
}
}
}
- }
- private def loadByName(s: String): JClass =
- (classLoader tryToInitializeClass s) getOrElse sys.error("Failed to load expected class: '" + s + "'")
-
- protected def parentClassLoader: ClassLoader =
- SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
-
- def getInterpreterClassLoader() = classLoader
-
- // Set the current Java "context" class loader to this interpreter's class loader
- def setContextClassLoader() = classLoader.setAsContext()
-
- /** Given a simple repl-defined name, returns the real name of
- * the class representing it, e.g. for "Bippy" it may return
- *
- * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
- */
- def generatedName(simpleName: String): Option[String] = {
- if (simpleName endsWith "$") optFlatName(simpleName.init) map (_ + "$")
- else optFlatName(simpleName)
- }
- def flatName(id: String) = optFlatName(id) getOrElse id
- def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
-
- def allDefinedNames = definedNameMap.keys.toList sortBy (_.toString)
- def pathToType(id: String): String = pathToName(newTypeName(id))
- def pathToTerm(id: String): String = pathToName(newTermName(id))
- def pathToName(name: Name): String = {
- if (definedNameMap contains name)
- definedNameMap(name) fullPath name
- else name.toString
- }
+ private def makeClassLoader(): AbstractFileClassLoader =
+ new TranslatingClassLoader(parentClassLoader match {
+ case null => ScalaClassLoader fromURLs compilerClasspath
+ case p => new URLClassLoader(compilerClasspath, p)
+ })
+
+ def getInterpreterClassLoader() = classLoader
+
+ // Set the current Java "context" class loader to this interpreter's class loader
+ def setContextClassLoader() = classLoader.setAsContext()
+
+ /** Given a simple repl-defined name, returns the real name of
+ * the class representing it, e.g. for "Bippy" it may return
+ * {{{
+ * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
+ * }}}
+ */
+ def generatedName(simpleName: String): Option[String] = {
+ if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING)
+ else optFlatName(simpleName)
+ }
+ def flatName(id: String) = optFlatName(id) getOrElse id
+ def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
+
+ def allDefinedNames = definedNameMap.keys.toList.sorted
+ def pathToType(id: String): String = pathToName(newTypeName(id))
+ def pathToTerm(id: String): String = pathToName(newTermName(id))
+ def pathToName(name: Name): String = {
+ if (definedNameMap contains name)
+ definedNameMap(name) fullPath name
+ else name.toString
+ }
- /** Most recent tree handled which wasn't wholly synthetic. */
- private def mostRecentlyHandledTree: Option[Tree] = {
- prevRequests.reverse foreach { req =>
- req.handlers.reverse foreach {
- case x: MemberDefHandler if x.definesValue && !isInternalVarName(x.name.toString) => return Some(x.member)
- case _ => ()
+ /** Most recent tree handled which wasn't wholly synthetic. */
+ private def mostRecentlyHandledTree: Option[Tree] = {
+ prevRequests.reverse foreach { req =>
+ req.handlers.reverse foreach {
+ case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
+ case _ => ()
+ }
}
+ None
}
- None
- }
-
- /** Stubs for work in progress. */
- def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
- for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) {
- DBG("Redefining type '%s'\n %s -> %s".format(name, t1, t2))
+
+ /** Stubs for work in progress. */
+ def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
+ for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) {
+ logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2))
+ }
}
- }
- def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
- for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) {
- // Printing the types here has a tendency to cause assertion errors, like
- // assertion failed: fatal: <refinement> has owner value x, but a class owner is required
- // so DBG is by-name now to keep it in the family. (It also traps the assertion error,
- // but we don't want to unnecessarily risk hosing the compiler's internal state.)
- DBG("Redefining term '%s'\n %s -> %s".format(name, t1, t2))
+ def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
+ for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) {
+ // Printing the types here has a tendency to cause assertion errors, like
+ // assertion failed: fatal: <refinement> has owner value x, but a class owner is required
+ // so DBG is by-name now to keep it in the family. (It also traps the assertion error,
+ // but we don't want to unnecessarily risk hosing the compiler's internal state.)
+ logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2))
+ }
}
- }
- def recordRequest(req: Request) {
- if (req == null || referencedNameMap == null)
- return
-
- prevRequests += req
- req.referencedNames foreach (x => referencedNameMap(x) = req)
-
- // warning about serially defining companions. It'd be easy
- // enough to just redefine them together but that may not always
- // be what people want so I'm waiting until I can do it better.
- if (!settings.nowarnings.value) {
+
+ def recordRequest(req: Request) {
+ if (req == null || referencedNameMap == null)
+ return
+
+ prevRequests += req
+ req.referencedNames foreach (x => referencedNameMap(x) = req)
+
+ // warning about serially defining companions. It'd be easy
+ // enough to just redefine them together but that may not always
+ // be what people want so I'm waiting until I can do it better.
for {
name <- req.definedNames filterNot (x => req.definedNames contains x.companionName)
oldReq <- definedNameMap get name.companionName
newSym <- req.definedSymbols get name
oldSym <- oldReq.definedSymbols get name.companionName
+ if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }
} {
- printMessage("warning: previously defined %s is not a companion to %s.".format(oldSym, newSym))
- printMessage("Companions must be defined together; you may wish to use :paste mode for this.")
+ afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym."))
+ replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
}
- }
-
- // Updating the defined name map
- req.definedNames foreach { name =>
- if (definedNameMap contains name) {
- if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req)
- else handleTermRedefinition(name.toTermName, definedNameMap(name), req)
+
+ // Updating the defined name map
+ req.definedNames foreach { name =>
+ if (definedNameMap contains name) {
+ if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req)
+ else handleTermRedefinition(name.toTermName, definedNameMap(name), req)
+ }
+ definedNameMap(name) = req
}
- definedNameMap(name) = req
}
- }
- /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
- def parse(line: String): Option[List[Tree]] = {
- var justNeedsMore = false
- reporter.withIncompleteHandler((pos,msg) => {justNeedsMore = true}) {
- // simple parse: just parse it, nothing else
- def simpleParse(code: String): List[Tree] = {
- reporter.reset()
- val unit = new CompilationUnit(new BatchSourceFile("<console>", code))
- val scanner = new syntaxAnalyzer.UnitParser(unit)
-
- scanner.templateStatSeq(false)._2
- }
- val trees = simpleParse(line)
-
- if (reporter.hasErrors) Some(Nil) // the result did not parse, so stop
- else if (justNeedsMore) None
- else Some(trees)
+ def replwarn(msg: => String) {
+ if (!settings.nowarnings.value)
+ printMessage(msg)
}
- }
-
- def isParseable(line: String): Boolean = {
- beSilentDuring {
- parse(line) match {
- case Some(xs) => xs.nonEmpty // parses as-is
- case None => true // incomplete
+
+ def isParseable(line: String): Boolean = {
+ beSilentDuring {
+ try parse(line) match {
+ case Some(xs) => xs.nonEmpty // parses as-is
+ case None => true // incomplete
+ }
+ catch { case x: Exception => // crashed the compiler
+ replwarn("Exception in isParseable(\"" + line + "\"): " + x)
+ false
+ }
}
}
+
+ def compileSourcesKeepingRun(sources: SourceFile*) = {
+ val run = new Run()
+ reporter.reset()
+ run compileSources sources.toList
+ (!reporter.hasErrors, run)
+ }
+
+ /** Compile an nsc SourceFile. Returns true if there are
+ * no compilation errors, or false otherwise.
+ */
+ def compileSources(sources: SourceFile*): Boolean =
+ compileSourcesKeepingRun(sources: _*)._1
+
+ /** Compile a string. Returns true if there are no
+ * compilation errors, or false otherwise.
+ */
+ def compileString(code: String): Boolean =
+ compileSources(new BatchSourceFile("<script>", code))
+
+ /** Build a request from the user. `trees` is `line` after being parsed.
+ */
+ private def buildRequest(line: String, trees: List[Tree]): Request = {
+ executingRequest = new Request(line, trees)
+ executingRequest
+ }
+
+ // rewriting "5 // foo" to "val x = { 5 // foo }" creates broken code because
+ // the close brace is commented out. Strip single-line comments.
+ // ... but for error message output reasons this is not used, and rather than
+ // enclosing in braces it is constructed like "val x =\n5 // foo".
+ private def removeComments(line: String): String = {
+ showCodeIfDebugging(line) // as we're about to lose our // show
+ line.lines map (s => s indexOf "//" match {
+ case -1 => s
+ case idx => s take idx
+ }) mkString "\n"
}
- /** Compile an nsc SourceFile. Returns true if there are
- * no compilation errors, or false otherwise.
- */
- def compileSources(sources: SourceFile*): Boolean = {
- reporter.reset()
- new Run() compileSources sources.toList
- !reporter.hasErrors
+ private def safePos(t: Tree, alt: Int): Int =
+ try t.pos.startOrPoint
+ catch { case _: UnsupportedOperationException => alt }
+
+ // Given an expression like 10 * 10 * 10 we receive the parent tree positioned
+ // at a '*'. So look at each subtree and find the earliest of all positions.
+ private def earliestPosition(tree: Tree): Int = {
+ var pos = Int.MaxValue
+ tree foreach { t =>
+ pos = math.min(pos, safePos(t, Int.MaxValue))
+ }
+ pos
}
- /** Compile a string. Returns true if there are no
- * compilation errors, or false otherwise.
- */
- def compileString(code: String): Boolean =
- compileSources(new BatchSourceFile("<script>", code))
- /** Build a request from the user. `trees` is `line` after being parsed.
- */
- private def buildRequest(line: String, trees: List[Tree]): Request = new Request(line, trees)
-
private def requestFromLine(line: String, synthetic: Boolean): Either[IR.Result, Request] = {
- val trees = parse(indentCode(line)) match {
+ val content = indentCode(line)
+ val trees = parse(content) match {
case None => return Left(IR.Incomplete)
case Some(Nil) => return Left(IR.Error) // parse error or empty input
case Some(trees) => trees
}
-
- // use synthetic vars to avoid filling up the resXX slots
- def varName = if (synthetic) freshInternalVarName() else freshUserVarName()
-
- // Treat a single bare expression specially. This is necessary due to it being hard to
- // modify code at a textual level, and it being hard to submit an AST to the compiler.
- if (trees.size == 1) trees.head match {
- case _:Assign => // we don't want to include assignments
- case _:TermTree | _:Ident | _:Select => // ... but do want these as valdefs.
- requestFromLine("val %s =\n%s".format(varName, line), synthetic) match {
+ logDebug(
+ trees map (t => {
+ // [Eugene to Paul] previously it just said `t map ...`
+ // because there was an implicit conversion from Tree to a list of Trees
+ // however Martin and I have removed the conversion
+ // (it was conflicting with the new reflection API),
+ // so I had to rewrite this a bit
+ val subs = t collect { case sub => sub }
+ subs map (t0 =>
+ " " + safePos(t0, -1) + ": " + t0.shortClass + "\n"
+ ) mkString ""
+ }) mkString "\n"
+ )
+ // If the last tree is a bare expression, pinpoint where it begins using the
+ // AST node position and snap the line off there. Rewrite the code embodied
+ // by the last tree as a ValDef instead, so we can access the value.
+ trees.last match {
+ case _:Assign => // we don't want to include assignments
+ case _:TermTree | _:Ident | _:Select => // ... but do want other unnamed terms.
+ val varName = if (synthetic) freshInternalVarName() else freshUserVarName()
+ val rewrittenLine = (
+ // In theory this would come out the same without the 1-specific test, but
+ // it's a cushion against any more sneaky parse-tree position vs. code mismatches:
+ // this way such issues will only arise on multiple-statement repl input lines,
+ // which most people don't use.
+ if (trees.size == 1) "val " + varName + " =\n" + content
+ else {
+ // The position of the last tree
+ val lastpos0 = earliestPosition(trees.last)
+ // Oh boy, the parser throws away parens so "(2+2)" is mispositioned,
+ // with increasingly hard to decipher positions as we move on to "() => 5",
+ // (x: Int) => x + 1, and more. So I abandon attempts to finesse and just
+ // look for semicolons and newlines, which I'm sure is also buggy.
+ val (raw1, raw2) = content splitAt lastpos0
+ logDebug("[raw] " + raw1 + " <---> " + raw2)
+
+ val adjustment = (raw1.reverse takeWhile (ch => (ch != ';') && (ch != '\n'))).size
+ val lastpos = lastpos0 - adjustment
+
+ // the source code split at the laboriously determined position.
+ val (l1, l2) = content splitAt lastpos
+ logDebug("[adj] " + l1 + " <---> " + l2)
+
+ val prefix = if (l1.trim == "") "" else l1 + ";\n"
+ // Note to self: val source needs to have this precise structure so that
+ // error messages print the user-submitted part without the "val res0 = " part.
+ val combined = prefix + "val " + varName + " =\n" + l2
+
+ logDebug(List(
+ " line" -> line,
+ " content" -> content,
+ " was" -> l2,
+ "combined" -> combined) map {
+ case (label, s) => label + ": '" + s + "'"
+ } mkString "\n"
+ )
+ combined
+ }
+ )
+ // Rewriting "foo ; bar ; 123"
+ // to "foo ; bar ; val resXX = 123"
+ requestFromLine(rewrittenLine, synthetic) match {
case Right(req) => return Right(req withOriginalLine line)
case x => return x
}
- case _ =>
+ case _ =>
}
-
- // figure out what kind of request
Right(buildRequest(line, trees))
}
- /**
- * Interpret one line of input. All feedback, including parse errors
- * and evaluation results, are printed via the supplied compiler's
- * reporter. Values defined are available for future interpreted
- * strings.
- *
- *
- * The return value is whether the line was interpreter successfully,
- * e.g. that there were no parse errors.
- *
+ // normalize non-public types so we don't see protected aliases like Self
+ def normalizeNonPublic(tp: Type) = tp match {
+ case TypeRef(_, sym, _) if sym.isAliasType && !sym.isPublic => tp.dealias
+ case _ => tp
+ }
+
+ /**
+ * Interpret one line of input. All feedback, including parse errors
+ * and evaluation results, are printed via the supplied compiler's
+ * reporter. Values defined are available for future interpreted strings.
*
- * @param line ...
- * @return ...
+ * The return value is whether the line was interpreter successfully,
+ * e.g. that there were no parse errors.
*/
def interpret(line: String): IR.Result = interpret(line, false)
+ def interpretSynthetic(line: String): IR.Result = interpret(line, true)
def interpret(line: String, synthetic: Boolean): IR.Result = {
def loadAndRunReq(req: Request) = {
+ classLoader.setAsContext()
val (result, succeeded) = req.loadAndRun
+
/** To our displeasure, ConsoleReporter offers only printMessage,
* which tacks a newline on the end. Since that breaks all the
* output checking, we have to take one off to balance.
*/
- def show() = {
- if (result == "") ()
- else printMessage(result stripSuffix "\n")
- }
-
if (succeeded) {
- if (printResults)
- show()
+ if (printResults && result != "")
+ printMessage(result stripSuffix "\n")
+ else if (isReplDebug) // show quiet-mode activity
+ printMessage(result.trim.lines map ("[quiet] " + _) mkString "\n")
+
// Book-keeping. Have to record synthetic requests too,
// as they may have been issued for information, e.g. :type
recordRequest(req)
IR.Success
}
- else {
- // don't truncate stack traces
- withoutTruncating(show())
- IR.Error
- }
+ else {
+ // don't truncate stack traces
+ withoutTruncating(printMessage(result))
+ IR.Error
+ }
}
-
+
if (global == null) IR.Error
else requestFromLine(line, synthetic) match {
case Left(result) => result
- case Right(req) =>
+ case Right(req) =>
// null indicates a disallowed statement type; otherwise compile and
// fail if false (implying e.g. a type error)
if (req == null || !req.compile) IR.Error
@@ -546,23 +652,39 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
* @param value the object value to bind to it
* @return an indication of whether the binding succeeded
*/
- def bind(name: String, boundType: String, value: Any): IR.Result = {
+ def bind(name: String, boundType: String, value: Any, modifiers: List[String] = Nil): IR.Result = {
val bindRep = new ReadEvalPrint()
val run = bindRep.compile("""
- |object %s {
- | var value: %s = _
- | def set(x: Any) = value = x.asInstanceOf[%s]
- |}
- """.stripMargin.format(bindRep.evalName, boundType, boundType)
- )
- bindRep.callOpt("set", value) match {
- case Some(_) => interpret("val %s = %s.value".format(name, bindRep.evalPath))
- case _ => DBG("Set failed in bind(%s, %s, %s)".format(name, boundType, value)) ; IR.Error
+ |object %s {
+ | var value: %s = _
+ | def set(x: Any) = value = x.asInstanceOf[%s]
+ |}
+ """.stripMargin.format(bindRep.evalName, boundType, boundType)
+ )
+ bindRep.callEither("set", value) match {
+ case Left(ex) =>
+ logDebug("Set failed in bind(%s, %s, %s)".format(name, boundType, value))
+ logDebug(util.stackTraceString(ex))
+ IR.Error
+
+ case Right(_) =>
+ val line = "%sval %s = %s.value".format(modifiers map (_ + " ") mkString, name, bindRep.evalPath)
+ logDebug("Interpreting: " + line)
+ interpret(line)
}
}
+ def directBind(name: String, boundType: String, value: Any): IR.Result = {
+ val result = bind(name, boundType, value)
+ if (result == IR.Success)
+ directlyBoundNames += newTermName(name)
+ result
+ }
+ def directBind(p: NamedParam): IR.Result = directBind(p.name, p.tpe, p.value)
+ def directBind[T: ru.TypeTag : ClassTag](name: String, value: T): IR.Result = directBind((name, value))
+
def rebind(p: NamedParam): IR.Result = {
val name = p.name
- val oldType = typeOfTerm(name) getOrElse { return IR.Error }
+ val oldType = typeOfTerm(name) orElse { return IR.Error }
val newType = p.tpe
val tempName = freshInternalVarName()
@@ -570,23 +692,27 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
quietRun("val %s = %s.asInstanceOf[%s]".format(name, tempName, newType))
}
def quietImport(ids: String*): IR.Result = beQuietDuring(addImports(ids: _*))
- def addImports(ids: String*): IR.Result =
+ def addImports(ids: String*): IR.Result =
if (ids.isEmpty) IR.Success
else interpret("import " + ids.mkString(", "))
- def quietBind(p: NamedParam): IR.Result = beQuietDuring(bind(p))
- def bind(p: NamedParam): IR.Result = bind(p.name, p.tpe, p.value)
- def bind[T: Manifest](name: String, value: T): IR.Result = bind((name, value))
- def bindValue(x: Any): IR.Result = bind(freshUserVarName(), TypeStrings.fromValue(x), x)
+ def quietBind(p: NamedParam): IR.Result = beQuietDuring(bind(p))
+ def bind(p: NamedParam): IR.Result = bind(p.name, p.tpe, p.value)
+ def bind[T: ru.TypeTag : ClassTag](name: String, value: T): IR.Result = bind((name, value))
+ def bindSyntheticValue(x: Any): IR.Result = bindValue(freshInternalVarName(), x)
+ def bindValue(x: Any): IR.Result = bindValue(freshUserVarName(), x)
+ def bindValue(name: String, x: Any): IR.Result = bind(name, TypeStrings.fromValue(x), x)
/** Reset this interpreter, forgetting all user-specified requests. */
def reset() {
- //virtualDirectory.clear()
- virtualDirectory.delete()
- virtualDirectory.create()
+ clearExecutionWrapper()
resetClassLoader()
resetAllCreators()
prevRequests.clear()
+ referencedNameMap.clear()
+ definedNameMap.clear()
+ virtualDirectory.delete()
+ virtualDirectory.create()
}
/** This instance is no longer needed, so release any resources
@@ -596,9 +722,9 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
reporter.flush()
classServer.stop()
}
-
+
/** Here is where we:
- *
+ *
* 1) Read some source code, and put it in the "read" object.
* 2) Evaluate the read object, and put the result in the "eval" object.
* 3) Create a String for human consumption, and put it in the "print" object.
@@ -608,115 +734,172 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
class ReadEvalPrint(lineId: Int) {
def this() = this(freshLineId())
- val packageName = "$line" + lineId
- val readName = "$read"
- val evalName = "$eval"
- val printName = "$print"
- val valueMethod = "$result" // no-args method giving result
-
+ private var lastRun: Run = _
+ private var evalCaught: Option[Throwable] = None
+ private var conditionalWarnings: List[ConditionalWarning] = Nil
+
+ val packageName = sessionNames.line + lineId
+ val readName = sessionNames.read
+ val evalName = sessionNames.eval
+ val printName = sessionNames.print
+ val resultName = sessionNames.result
+
+ def bindError(t: Throwable) = {
+ if (!bindExceptions) // avoid looping if already binding
+ throw t
+
+ val unwrapped = unwrap(t)
+ withLastExceptionLock[String]({
+ directBind[Throwable]("lastException", unwrapped)(tagOfThrowable, classTag[Throwable])
+ util.stackTraceString(unwrapped)
+ }, util.stackTraceString(unwrapped))
+ }
+
// TODO: split it out into a package object and a regular
// object and we can do that much less wrapping.
def packageDecl = "package " + packageName
-
+
def pathTo(name: String) = packageName + "." + name
def packaged(code: String) = packageDecl + "\n\n" + code
def readPath = pathTo(readName)
def evalPath = pathTo(evalName)
def printPath = pathTo(printName)
-
- def call(name: String, args: Any*): AnyRef =
- evalMethod(name).invoke(evalClass, args.map(_.asInstanceOf[AnyRef]): _*)
-
+
+ def call(name: String, args: Any*): AnyRef = {
+ val m = evalMethod(name)
+ logDebug("Invoking: " + m)
+ if (args.nonEmpty)
+ logDebug(" with args: " + args.mkString(", "))
+
+ m.invoke(evalClass, args.map(_.asInstanceOf[AnyRef]): _*)
+ }
+
+ def callEither(name: String, args: Any*): Either[Throwable, AnyRef] =
+ try Right(call(name, args: _*))
+ catch { case ex: Throwable => Left(ex) }
+
def callOpt(name: String, args: Any*): Option[AnyRef] =
try Some(call(name, args: _*))
- catch { case ex: Exception =>
- quietBind("lastException", ex)
- None
- }
-
- lazy val evalClass = loadByName(evalPath)
- lazy val evalValue = callOpt(valueMethod)
+ catch { case ex: Throwable => bindError(ex) ; None }
- def compile(source: String): Boolean = compileAndSaveRun("<console>", source)
- def lineAfterTyper[T](op: => T): T = {
- assert(lastRun != null, "Internal error: trying to use atPhase, but Run is null." + this)
- atPhase(lastRun.typerPhase.next)(op)
+ class EvalException(msg: String, cause: Throwable) extends RuntimeException(msg, cause) { }
+
+ private def evalError(path: String, ex: Throwable) =
+ throw new EvalException("Failed to load '" + path + "': " + ex.getMessage, ex)
+
+ private def load(path: String): Class[_] = {
+ try Class.forName(path, true, classLoader)
+ catch { case ex: Throwable => evalError(path, unwrap(ex)) }
+ }
+
+ lazy val evalClass = load(evalPath)
+ lazy val evalValue = callEither(resultName) match {
+ case Left(ex) => evalCaught = Some(ex) ; None
+ case Right(result) => Some(result)
}
-
+
+ def compile(source: String): Boolean = compileAndSaveRun("<console>", source)
+
/** The innermost object inside the wrapper, found by
- * following accessPath into the outer one.
- */
+ * following accessPath into the outer one.
+ */
def resolvePathToSymbol(accessPath: String): Symbol = {
- //val readRoot = definitions.getModule(readPath) // the outermost wrapper
- // MATEI: changed this to getClass because the root object is no longer a module (Scala singleton object)
- val readRoot = definitions.getClass(readPath) // the outermost wrapper
- (accessPath split '.').foldLeft(readRoot) { (sym, name) =>
- if (name == "") sym else
- lineAfterTyper(sym.info member newTermName(name))
+ // val readRoot = getRequiredModule(readPath) // the outermost wrapper
+ // MATEI: Changed this to getClass because the root object is no longer a module (Scala singleton object)
+
+ val readRoot = rootMirror.getClassByName(newTypeName(readPath)) // the outermost wrapper
+ (accessPath split '.').foldLeft(readRoot: Symbol) {
+ case (sym, "") => sym
+ case (sym, name) => afterTyper(termMember(sym, name))
}
}
-
- // def compileAndTypeExpr(expr: String): Option[Typer] = {
- // class TyperRun extends Run {
- // override def stopPhase(name: String) = name == "superaccessors"
- // }
- // }
- private var lastRun: Run = _
- private def evalMethod(name: String) = {
- val methods = evalClass.getMethods filter (_.getName == name)
- assert(methods.size == 1, "Internal error - eval object method " + name + " is overloaded: " + methods)
- methods.head
+ /** We get a bunch of repeated warnings for reasons I haven't
+ * entirely figured out yet. For now, squash.
+ */
+ private def updateRecentWarnings(run: Run) {
+ def loop(xs: List[(Position, String)]): List[(Position, String)] = xs match {
+ case Nil => Nil
+ case ((pos, msg)) :: rest =>
+ val filtered = rest filter { case (pos0, msg0) =>
+ (msg != msg0) || (pos.lineContent.trim != pos0.lineContent.trim) || {
+ // same messages and same line content after whitespace removal
+ // but we want to let through multiple warnings on the same line
+ // from the same run. The untrimmed line will be the same since
+ // there's no whitespace indenting blowing it.
+ (pos.lineContent == pos0.lineContent)
+ }
+ }
+ ((pos, msg)) :: loop(filtered)
+ }
+ //PRASHANT: This leads to a NoSuchMethodError for _.warnings. Yet to figure out its purpose.
+ // val warnings = loop(run.allConditionalWarnings flatMap (_.warnings))
+ // if (warnings.nonEmpty)
+ // mostRecentWarnings = warnings
+ }
+ private def evalMethod(name: String) = evalClass.getMethods filter (_.getName == name) match {
+ case Array(method) => method
+ case xs => sys.error("Internal error: eval object " + evalClass + ", " + xs.mkString("\n", "\n", ""))
}
private def compileAndSaveRun(label: String, code: String) = {
showCodeIfDebugging(code)
- reporter.reset()
- lastRun = new Run()
- lastRun.compileSources(List(new BatchSourceFile(label, packaged(code))))
- !reporter.hasErrors
+ val (success, run) = compileSourcesKeepingRun(new BatchSourceFile(label, packaged(code)))
+ updateRecentWarnings(run)
+ lastRun = run
+ success
}
}
/** One line of code submitted by the user for interpretation */
- // private
+ // private
class Request(val line: String, val trees: List[Tree]) {
- val lineRep = new ReadEvalPrint()
- import lineRep.lineAfterTyper
-
+ val reqId = nextReqId()
+ val lineRep = new ReadEvalPrint()
+
private var _originalLine: String = null
def withOriginalLine(s: String): this.type = { _originalLine = s ; this }
def originalLine = if (_originalLine == null) line else _originalLine
/** handlers for each tree in this request */
val handlers: List[MemberHandler] = trees map (memberHandlers chooseHandler _)
+ def defHandlers = handlers collect { case x: MemberDefHandler => x }
/** all (public) names defined by these statements */
val definedNames = handlers flatMap (_.definedNames)
/** list of names used by this expression */
val referencedNames: List[Name] = handlers flatMap (_.referencedNames)
-
+
/** def and val names */
def termNames = handlers flatMap (_.definesTerm)
def typeNames = handlers flatMap (_.definesType)
+ def definedOrImported = handlers flatMap (_.definedOrImported)
+ def definedSymbolList = defHandlers flatMap (_.definedSymbols)
+
+ def definedTypeSymbol(name: String) = definedSymbols(newTypeName(name))
+ def definedTermSymbol(name: String) = definedSymbols(newTermName(name))
/** Code to import bound names from previous lines - accessPath is code to
- * append to objectName to access anything bound by request.
- */
- val ComputedImports(importsPreamble, importsTrailer, accessPath) =
+ * append to objectName to access anything bound by request.
+ */
+ val SparkComputedImports(importsPreamble, importsTrailer, accessPath) =
importsCode(referencedNames.toSet)
/** Code to access a variable with the specified name */
- def fullPath(vname: String) = (
- //lineRep.readPath + accessPath + ".`%s`".format(vname)
+ def fullPath(vname: String) = {
+ // lineRep.readPath + accessPath + ".`%s`".format(vname)
lineRep.readPath + ".INSTANCE" + accessPath + ".`%s`".format(vname)
- )
- /** Same as fullpath, but after it has been flattened, so:
- * $line5.$iw.$iw.$iw.Bippy // fullPath
- * $line5.$iw$$iw$$iw$Bippy // fullFlatName
- */
- def fullFlatName(name: String) =
- lineRep.readPath + accessPath.replace('.', '$') + "$" + name
+ }
+ /** Same as fullpath, but after it has been flattened, so:
+ * $line5.$iw.$iw.$iw.Bippy // fullPath
+ * $line5.$iw$$iw$$iw$Bippy // fullFlatName
+ */
+ def fullFlatName(name: String) =
+ // lineRep.readPath + accessPath.replace('.', '$') + nme.NAME_JOIN_STRING + name
+ lineRep.readPath + ".INSTANCE" + accessPath.replace('.', '$') + nme.NAME_JOIN_STRING + name
+
+ /** The unmangled symbol name, but supplemented with line info. */
+ def disambiguated(name: Name): String = name + " (in " + lineRep + ")"
/** Code to access a variable with the specified name */
def fullPath(vname: Name): String = fullPath(vname.toString)
@@ -726,52 +909,66 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
/** generate the source code for the object that computes this request */
private object ObjectSourceCode extends CodeAssembler[MemberHandler] {
+ def path = pathToTerm("$intp")
+ def envLines = {
+ if (!isReplPower) Nil // power mode only for now
+ // $intp is not bound; punt, but include the line.
+ else if (path == "$intp") List(
+ "def $line = " + tquoted(originalLine),
+ "def $trees = Nil"
+ )
+ else List(
+ "def $line = " + tquoted(originalLine),
+ "def $req = %s.requestForReqId(%s).orNull".format(path, reqId),
+ "def $trees = if ($req eq null) Nil else $req.trees".format(lineRep.readName, path, reqId)
+ )
+ }
+
val preamble = """
|class %s extends Serializable {
- | %s%s
- """.stripMargin.format(lineRep.readName, importsPreamble, indentCode(toCompute))
+ | %s%s%s
+ """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute))
val postamble = importsTrailer + "\n}" + "\n" +
"object " + lineRep.readName + " {\n" +
" val INSTANCE = new " + lineRep.readName + "();\n" +
"}\n"
val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this
+
/*
val preamble = """
- |object %s {
- | %s%s
- """.stripMargin.format(lineRep.readName, importsPreamble, indentCode(toCompute))
+ |object %s extends Serializable {
+ |%s%s%s
+ """.stripMargin.format(lineRep.readName, envLines.map(" " + _ + ";\n").mkString, importsPreamble, indentCode(toCompute))
val postamble = importsTrailer + "\n}"
val generate = (m: MemberHandler) => m extraCodeToEvaluate Request.this
*/
+
}
-
+
private object ResultObjectSourceCode extends CodeAssembler[MemberHandler] {
/** We only want to generate this code when the result
* is a value which can be referred to as-is.
- */
+ */
val evalResult =
if (!handlers.last.definesValue) ""
else handlers.last.definesTerm match {
case Some(vname) if typeOf contains vname =>
- """
- |lazy val $result = {
- | $export
- | %s
- |}""".stripMargin.format(fullPath(vname))
+ "lazy val %s = %s".format(lineRep.resultName, fullPath(vname))
case _ => ""
}
+
// first line evaluates object to make sure constructor is run
// initial "" so later code can uniformly be: + etc
val preamble = """
|object %s {
| %s
- | val $export: String = %s {
+ | val %s: String = %s {
| %s
| (""
""".stripMargin.format(
- lineRep.evalName, evalResult, executionWrapper, lineRep.readName + ".INSTANCE" + accessPath
+ lineRep.evalName, evalResult, lineRep.printName,
+ executionWrapper, lineRep.readName + ".INSTANCE" + accessPath
)
-
val postamble = """
| )
| }
@@ -785,7 +982,7 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
def getEval: Option[AnyRef] = {
// ensure it has been compiled
compile
- // try to load it and call the value method
+ // try to load it and call the value method
lineRep.evalValue filterNot (_ == null)
}
@@ -797,114 +994,54 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
// compile the object containing the user's code
lineRep.compile(ObjectSourceCode(handlers)) && {
- // extract and remember types
+ // extract and remember types
typeOf
typesOfDefinedTerms
+ // Assign symbols to the original trees
+ // TODO - just use the new trees.
+ defHandlers foreach { dh =>
+ val name = dh.member.name
+ definedSymbols get name foreach { sym =>
+ dh.member setSymbol sym
+ logDebug("Set symbol of " + name + " to " + sym.defString)
+ }
+ }
+
// compile the result-extraction object
- lineRep compile ResultObjectSourceCode(handlers)
+ withoutWarnings(lineRep compile ResultObjectSourceCode(handlers))
}
}
lazy val resultSymbol = lineRep.resolvePathToSymbol(accessPath)
- def applyToResultMember[T](name: Name, f: Symbol => T) = lineAfterTyper(f(resultSymbol.info.nonPrivateDecl(name)))
+ def applyToResultMember[T](name: Name, f: Symbol => T) = afterTyper(f(resultSymbol.info.nonPrivateDecl(name)))
/* typeOf lookup with encoding */
- def lookupTypeOf(name: Name) = {
- typeOf.getOrElse(name, typeOf(global.encode(name.toString)))
- }
- def simpleNameOfType(name: TypeName) = {
- (compilerTypeOf get name) map (_.typeSymbol.simpleName)
- }
-
- private def typeMap[T](f: Type => T): Map[Name, T] = {
- def toType(name: Name): T = {
- // the types are all =>T; remove the =>
- val tp1 = lineAfterTyper(resultSymbol.info.nonPrivateDecl(name).tpe match {
- case NullaryMethodType(tp) => tp
- case tp => tp
- })
- // normalize non-public types so we don't see protected aliases like Self
- lineAfterTyper(tp1 match {
- case TypeRef(_, sym, _) if !sym.isPublic => f(tp1.normalize)
- case tp => f(tp)
- })
- }
- termNames ++ typeNames map (x => x -> toType(x)) toMap
- }
+ def lookupTypeOf(name: Name) = typeOf.getOrElse(name, typeOf(global.encode(name.toString)))
+ def simpleNameOfType(name: TypeName) = (compilerTypeOf get name) map (_.typeSymbol.simpleName)
+
+ private def typeMap[T](f: Type => T) =
+ mapFrom[Name, Name, T](termNames ++ typeNames)(x => f(cleanMemberDecl(resultSymbol, x)))
+
/** Types of variables defined by this request. */
- lazy val compilerTypeOf = typeMap[Type](x => x)
+ lazy val compilerTypeOf = typeMap[Type](x => x) withDefaultValue NoType
/** String representations of same. */
- lazy val typeOf = typeMap[String](_.toString)
-
+ lazy val typeOf = typeMap[String](tp => afterTyper(tp.toString))
+
// lazy val definedTypes: Map[Name, Type] = {
// typeNames map (x => x -> afterTyper(resultSymbol.info.nonPrivateDecl(x).tpe)) toMap
// }
- lazy val definedSymbols: Map[Name, Symbol] = (
+ lazy val definedSymbols = (
termNames.map(x => x -> applyToResultMember(x, x => x)) ++
- typeNames.map(x => x -> compilerTypeOf.get(x).map(_.typeSymbol).getOrElse(NoSymbol))
- ).toMap
-
- lazy val typesOfDefinedTerms: Map[Name, Type] =
- termNames map (x => x -> applyToResultMember(x, _.tpe)) toMap
-
- private def bindExceptionally(t: Throwable) = {
- val ex: Exceptional =
- if (isettings.showInternalStackTraces) Exceptional(t)
- else new Exceptional(t) {
- override def spanFn(frame: JavaStackFrame) = !(frame.className startsWith lineRep.evalPath)
- override def contextPrelude = super.contextPrelude + "/* The repl internal portion of the stack trace is elided. */\n"
- }
-
- quietBind("lastException", ex)
- ex.contextHead + "\n(access lastException for the full trace)"
- }
- private def bindUnexceptionally(t: Throwable) = {
- quietBind("lastException", t)
- stackTraceString(t)
- }
+ typeNames.map(x => x -> compilerTypeOf(x).typeSymbolDirect)
+ ).toMap[Name, Symbol] withDefaultValue NoSymbol
+
+ lazy val typesOfDefinedTerms = mapFrom[Name, Name, Type](termNames)(x => applyToResultMember(x, _.tpe))
/** load and run the code using reflection */
def loadAndRun: (String, Boolean) = {
- import interpreter.Line._
-
- def handleException(t: Throwable) = {
- /** We turn off the binding to accomodate ticket #2817 */
- withoutBindingLastException {
- val message =
- if (opt.richExes) bindExceptionally(unwrap(t))
- else bindUnexceptionally(unwrap(t))
-
- (message, false)
- }
- }
-
- try {
- val execution = lineManager.set(originalLine) {
- // MATEI: set the right SparkEnv for our SparkContext, because
- // this execution will happen in a separate thread
- val sc = org.apache.spark.repl.Main.interp.sparkContext
- if (sc != null && sc.env != null)
- SparkEnv.set(sc.env)
- // Execute the line
- lineRep call "$export"
- }
- execution.await()
-
- execution.state match {
- case Done => ("" + execution.get(), true)
- case Threw =>
- val ex = execution.caught()
- if (isReplDebug)
- ex.printStackTrace()
-
- if (bindLastException) handleException(ex)
- else throw ex
- case Cancelled => ("Execution interrupted by signal.\n", false)
- case Running => ("Execution still running! Seems impossible.", false)
- }
- }
- finally lineManager.clear()
+ try { ("" + (lineRep call sessionNames.print), true) }
+ catch { case ex: Throwable => (lineRep.bindError(ex), false) }
}
override def toString = "Request(line=%s, %s trees)".format(line, trees.size)
@@ -922,136 +1059,157 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
case ModuleDef(_, name, _) => name
case _ => naming.mostRecentVar
})
-
- private def requestForName(name: Name): Option[Request] = {
+
+ private var mostRecentWarnings: List[(global.Position, String)] = Nil
+ def lastWarnings = mostRecentWarnings
+
+ def treesForRequestId(id: Int): List[Tree] =
+ requestForReqId(id).toList flatMap (_.trees)
+
+ def requestForReqId(id: Int): Option[Request] =
+ if (executingRequest != null && executingRequest.reqId == id) Some(executingRequest)
+ else prevRequests find (_.reqId == id)
+
+ def requestForName(name: Name): Option[Request] = {
assert(definedNameMap != null, "definedNameMap is null")
definedNameMap get name
}
- private def requestForIdent(line: String): Option[Request] =
+ def requestForIdent(line: String): Option[Request] =
requestForName(newTermName(line)) orElse requestForName(newTypeName(line))
-
- def safeClass(name: String): Option[Symbol] = {
- try Some(definitions.getClass(newTypeName(name)))
- catch { case _: MissingRequirementError => None }
- }
- def safeModule(name: String): Option[Symbol] = {
- try Some(definitions.getModule(newTermName(name)))
- catch { case _: MissingRequirementError => None }
- }
+
+ def requestHistoryForName(name: Name): List[Request] =
+ prevRequests.toList.reverse filter (_.definedNames contains name)
+
def definitionForName(name: Name): Option[MemberHandler] =
requestForName(name) flatMap { req =>
req.handlers find (_.definedNames contains name)
}
-
+
def valueOfTerm(id: String): Option[AnyRef] =
- requestForIdent(id) flatMap (_.getEval)
+ requestForName(newTermName(id)) flatMap (_.getEval)
def classOfTerm(id: String): Option[JClass] =
- valueOfTerm(id) map (_.getClass)
+ valueOfTerm(id) map (_.getClass)
- def typeOfTerm(id: String): Option[Type] = newTermName(id) match {
- case nme.ROOTPKG => Some(definitions.RootClass.tpe)
- case name => requestForName(name) flatMap (_.compilerTypeOf get name)
+ def typeOfTerm(id: String): Type = newTermName(id) match {
+ case nme.ROOTPKG => RootClass.tpe
+ case name => requestForName(name).fold(NoType: Type)(_ compilerTypeOf name)
}
+
+ def symbolOfType(id: String): Symbol =
+ requestForName(newTypeName(id)).fold(NoSymbol: Symbol)(_ definedTypeSymbol id)
+
def symbolOfTerm(id: String): Symbol =
- requestForIdent(id) flatMap (_.definedSymbols get newTermName(id)) getOrElse NoSymbol
+ requestForIdent(newTermName(id)).fold(NoSymbol: Symbol)(_ definedTermSymbol id)
def runtimeClassAndTypeOfTerm(id: String): Option[(JClass, Type)] = {
- for {
- clazz <- classOfTerm(id)
- tpe <- runtimeTypeOfTerm(id)
- nonAnon <- new RichClass(clazz).supers.find(c => !(new RichClass(c).isScalaAnonymous))
- } yield {
- (nonAnon, tpe)
- }
- }
-
- def runtimeTypeOfTerm(id: String): Option[Type] = {
- for {
- tpe <- typeOfTerm(id)
- clazz <- classOfTerm(id)
- val staticSym = tpe.typeSymbol
- runtimeSym <- safeClass(clazz.getName)
- if runtimeSym != staticSym
- if runtimeSym isSubClass staticSym
- } yield {
- runtimeSym.info
+ classOfTerm(id) flatMap { clazz =>
+ new RichClass(clazz).supers find(c => !(new RichClass(c).isScalaAnonymous)) map { nonAnon =>
+ (nonAnon, runtimeTypeOfTerm(id))
+ }
}
}
-
- // XXX literals.
- // 1) Identifiers defined in the repl.
- // 2) A path loadable via getModule.
- // 3) Try interpreting it as an expression.
- private var typeOfExpressionDepth = 0
- def typeOfExpression(expr: String): Option[Type] = {
- DBG("typeOfExpression(" + expr + ")")
- if (typeOfExpressionDepth > 2) {
- DBG("Terminating typeOfExpression recursion for expression: " + expr)
- return None
- }
- def asQualifiedImport = {
- val name = expr.takeWhile(_ != '.')
- importedTermNamed(name) flatMap { sym =>
- typeOfExpression(sym.fullName + expr.drop(name.length))
- }
+ def runtimeTypeOfTerm(id: String): Type = {
+ typeOfTerm(id) andAlso { tpe =>
+ val clazz = classOfTerm(id) getOrElse { return NoType }
+ val staticSym = tpe.typeSymbol
+ val runtimeSym = getClassIfDefined(clazz.getName)
+
+ if ((runtimeSym != NoSymbol) && (runtimeSym != staticSym) && (runtimeSym isSubClass staticSym))
+ runtimeSym.info
+ else NoType
}
- def asModule = safeModule(expr) map (_.tpe)
- def asExpr = beSilentDuring {
- val lhs = freshInternalVarName()
- val line = "lazy val " + lhs + " = { " + expr + " } "
-
- interpret(line, true) match {
- case IR.Success => typeOfExpression(lhs)
- case _ => None
+ }
+ def cleanMemberDecl(owner: Symbol, member: Name): Type = afterTyper {
+ normalizeNonPublic {
+ owner.info.nonPrivateDecl(member).tpe match {
+ case NullaryMethodType(tp) => tp
+ case tp => tp
}
}
-
- typeOfExpressionDepth += 1
- try typeOfTerm(expr) orElse asModule orElse asExpr orElse asQualifiedImport
- finally typeOfExpressionDepth -= 1
}
- // def compileAndTypeExpr(expr: String): Option[Typer] = {
- // class TyperRun extends Run {
- // override def stopPhase(name: String) = name == "superaccessors"
- // }
- // }
-
+
+ object exprTyper extends {
+ val repl: SparkIMain.this.type = imain
+ } with SparkExprTyper { }
+
+ def parse(line: String): Option[List[Tree]] = exprTyper.parse(line)
+
+ def symbolOfLine(code: String): Symbol =
+ exprTyper.symbolOfLine(code)
+
+ def typeOfExpression(expr: String, silent: Boolean = true): Type =
+ exprTyper.typeOfExpression(expr, silent)
+
protected def onlyTerms(xs: List[Name]) = xs collect { case x: TermName => x }
protected def onlyTypes(xs: List[Name]) = xs collect { case x: TypeName => x }
-
- def definedTerms = onlyTerms(allDefinedNames) filterNot (x => isInternalVarName(x.toString))
- def definedTypes = onlyTypes(allDefinedNames)
- def definedSymbols = prevRequests.toSet flatMap ((x: Request) => x.definedSymbols.values)
-
+
+ def definedTerms = onlyTerms(allDefinedNames) filterNot isInternalTermName
+ def definedTypes = onlyTypes(allDefinedNames)
+ def definedSymbols = prevRequestList.flatMap(_.definedSymbols.values).toSet[Symbol]
+ def definedSymbolList = prevRequestList flatMap (_.definedSymbolList) filterNot (s => isInternalTermName(s.name))
+
+ // Terms with user-given names (i.e. not res0 and not synthetic)
+ def namedDefinedTerms = definedTerms filterNot (x => isUserVarName("" + x) || directlyBoundNames(x))
+
+ private def findName(name: Name) = definedSymbols find (_.name == name) getOrElse NoSymbol
+
+ /** Translate a repl-defined identifier into a Symbol.
+ */
+ def apply(name: String): Symbol =
+ types(name) orElse terms(name)
+
+ def types(name: String): Symbol = {
+ val tpname = newTypeName(name)
+ findName(tpname) orElse getClassIfDefined(tpname)
+ }
+ def terms(name: String): Symbol = {
+ val termname = newTypeName(name)
+ findName(termname) orElse getModuleIfDefined(termname)
+ }
+ // [Eugene to Paul] possibly you could make use of TypeTags here
+ def types[T: ClassTag] : Symbol = types(classTag[T].runtimeClass.getName)
+ def terms[T: ClassTag] : Symbol = terms(classTag[T].runtimeClass.getName)
+ def apply[T: ClassTag] : Symbol = apply(classTag[T].runtimeClass.getName)
+
+ def classSymbols = allDefSymbols collect { case x: ClassSymbol => x }
+ def methodSymbols = allDefSymbols collect { case x: MethodSymbol => x }
+
/** the previous requests this interpreter has processed */
- private lazy val prevRequests = mutable.ListBuffer[Request]()
- private lazy val referencedNameMap = mutable.Map[Name, Request]()
- private lazy val definedNameMap = mutable.Map[Name, Request]()
- protected def prevRequestList = prevRequests.toList
- private def allHandlers = prevRequestList flatMap (_.handlers)
- def allSeenTypes = prevRequestList flatMap (_.typeOf.values.toList) distinct
- def allImplicits = allHandlers filter (_.definesImplicit) flatMap (_.definedNames)
- def importHandlers = allHandlers collect { case x: ImportHandler => x }
-
+ private var executingRequest: Request = _
+ private val prevRequests = mutable.ListBuffer[Request]()
+ private val referencedNameMap = mutable.Map[Name, Request]()
+ private val definedNameMap = mutable.Map[Name, Request]()
+ private val directlyBoundNames = mutable.Set[Name]()
+
+ def allHandlers = prevRequestList flatMap (_.handlers)
+ def allDefHandlers = allHandlers collect { case x: MemberDefHandler => x }
+ def allDefSymbols = allDefHandlers map (_.symbol) filter (_ ne NoSymbol)
+
+ def lastRequest = if (prevRequests.isEmpty) null else prevRequests.last
+ def prevRequestList = prevRequests.toList
+ def allSeenTypes = prevRequestList flatMap (_.typeOf.values.toList) distinct
+ def allImplicits = allHandlers filter (_.definesImplicit) flatMap (_.definedNames)
+ def importHandlers = allHandlers collect { case x: ImportHandler => x }
+
def visibleTermNames: List[Name] = definedTerms ++ importedTerms distinct
/** Another entry point for tab-completion, ids in scope */
def unqualifiedIds = visibleTermNames map (_.toString) filterNot (_ contains "$") sorted
-
+
/** Parse the ScalaSig to find type aliases */
def aliasForType(path: String) = ByteCode.aliasForType(path)
-
+
def withoutUnwrapping(op: => Unit): Unit = {
val saved = isettings.unwrapStrings
isettings.unwrapStrings = false
try op
finally isettings.unwrapStrings = saved
}
-
+
def symbolDefString(sym: Symbol) = {
TypeStrings.quieter(
afterTyper(sym.defString),
@@ -1059,38 +1217,41 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
sym.owner.fullName + "."
)
}
-
+
def showCodeIfDebugging(code: String) {
/** Secret bookcase entrance for repl debuggers: end the line
* with "// show" and see what's going on.
*/
- if (SPARK_DEBUG_REPL || code.lines.exists(_.trim endsWith "// show")) {
- echo(code)
- parse(code) foreach (ts => ts foreach (t => withoutUnwrapping(DBG(asCompactString(t)))))
+ def isShow = code.lines exists (_.trim endsWith "// show")
+ def isShowRaw = code.lines exists (_.trim endsWith "// raw")
+
+ // old style
+ beSilentDuring(parse(code)) foreach { ts =>
+ ts foreach { t =>
+ withoutUnwrapping(logDebug(asCompactString(t)))
+ }
}
}
+
// debugging
def debugging[T](msg: String)(res: T) = {
- DBG(msg + " " + res)
+ logDebug(msg + " " + res)
res
}
- def DBG(s: => String) = if (isReplDebug) {
- //try repldbg(s)
- //catch { case x: AssertionError => repldbg("Assertion error printing debug string:\n " + x) }
- }
}
/** Utility methods for the Interpreter. */
object SparkIMain {
// The two name forms this is catching are the two sides of this assignment:
//
- // $line3.$read.$iw.$iw.Bippy =
+ // $line3.$read.$iw.$iw.Bippy =
// $line3.$read$$iw$$iw$Bippy@4a6a00ca
private def removeLineWrapper(s: String) = s.replaceAll("""\$line\d+[./]\$(read|eval|print)[$.]""", "")
private def removeIWPackages(s: String) = s.replaceAll("""\$(iw|iwC|read|eval|print)[$.]""", "")
private def removeSparkVals(s: String) = s.replaceAll("""\$VAL[0-9]+[$.]""", "")
+
def stripString(s: String) = removeSparkVals(removeIWPackages(removeLineWrapper(s)))
-
+
trait CodeAssembler[T] {
def preamble: String
def generate: T => String
@@ -1102,7 +1263,7 @@ object SparkIMain {
code println postamble
}
}
-
+
trait StrippingWriter {
def isStripping: Boolean
def stripImpl(str: String): String
@@ -1112,17 +1273,17 @@ object SparkIMain {
def maxStringLength: Int
def isTruncating: Boolean
def truncate(str: String): String = {
- if (isTruncating && str.length > maxStringLength)
+ if (isTruncating && (maxStringLength != 0 && str.length > maxStringLength))
(str take maxStringLength - 3) + "..."
else str
}
}
- abstract class StrippingTruncatingWriter(out: PrintWriter)
- extends PrintWriter(out)
+ abstract class StrippingTruncatingWriter(out: JPrintWriter)
+ extends JPrintWriter(out)
with StrippingWriter
with TruncatingWriter {
self =>
-
+
def clean(str: String): String = truncate(strip(str))
override def write(str: String) = super.write(clean(str))
}
@@ -1132,18 +1293,7 @@ object SparkIMain {
def isStripping = isettings.unwrapStrings
def isTruncating = reporter.truncationOK
- def stripImpl(str: String): String = {
- val cleaned = stripString(str)
- var ctrlChars = 0
- cleaned map { ch =>
- if (ch.isControl && !ch.isWhitespace) {
- ctrlChars += 1
- if (ctrlChars > 5) return "[line elided for control chars: possibly a scala signature]"
- else '?'
- }
- else ch
- }
- }
+ def stripImpl(str: String): String = naming.unmangle(str)
}
class ReplReporter(intp: SparkIMain) extends ConsoleReporter(intp.settings, null, new ReplStrippingWriter(intp)) {
@@ -1156,5 +1306,55 @@ object SparkIMain {
}
else Console.println(msg)
}
- }
+ }
+}
+
+class SparkISettings(intp: SparkIMain) extends Logging {
+ /** A list of paths where :load should look */
+ var loadPath = List(".")
+
+ /** Set this to true to see repl machinery under -Yrich-exceptions.
+ */
+ var showInternalStackTraces = false
+
+ /** The maximum length of toString to use when printing the result
+ * of an evaluation. 0 means no maximum. If a printout requires
+ * more than this number of characters, then the printout is
+ * truncated.
+ */
+ var maxPrintString = 800
+
+ /** The maximum number of completion candidates to print for tab
+ * completion without requiring confirmation.
+ */
+ var maxAutoprintCompletion = 250
+
+ /** String unwrapping can be disabled if it is causing issues.
+ * Settings this to false means you will see Strings like "$iw.$iw.".
+ */
+ var unwrapStrings = true
+
+ def deprecation_=(x: Boolean) = {
+ val old = intp.settings.deprecation.value
+ intp.settings.deprecation.value = x
+ if (!old && x) logDebug("Enabled -deprecation output.")
+ else if (old && !x) logDebug("Disabled -deprecation output.")
+ }
+
+ def deprecation: Boolean = intp.settings.deprecation.value
+
+ def allSettings = Map(
+ "maxPrintString" -> maxPrintString,
+ "maxAutoprintCompletion" -> maxAutoprintCompletion,
+ "unwrapStrings" -> unwrapStrings,
+ "deprecation" -> deprecation
+ )
+
+ private def allSettingsString =
+ allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString
+
+ override def toString = """
+ | SparkISettings {
+ | %s
+ | }""".stripMargin.format(allSettingsString)
}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala b/repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala
deleted file mode 100644
index 605b7b259b..0000000000
--- a/repl/src/main/scala/org/apache/spark/repl/SparkISettings.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
- * @author Alexander Spoon
- */
-
-package org.apache.spark.repl
-
-import scala.tools.nsc._
-import scala.tools.nsc.interpreter._
-
-/** Settings for the interpreter
- *
- * @version 1.0
- * @author Lex Spoon, 2007/3/24
- **/
-class SparkISettings(intp: SparkIMain) {
- /** A list of paths where :load should look */
- var loadPath = List(".")
-
- /** Set this to true to see repl machinery under -Yrich-exceptions.
- */
- var showInternalStackTraces = false
-
- /** The maximum length of toString to use when printing the result
- * of an evaluation. 0 means no maximum. If a printout requires
- * more than this number of characters, then the printout is
- * truncated.
- */
- var maxPrintString = 800
-
- /** The maximum number of completion candidates to print for tab
- * completion without requiring confirmation.
- */
- var maxAutoprintCompletion = 250
-
- /** String unwrapping can be disabled if it is causing issues.
- * Settings this to false means you will see Strings like "$iw.$iw.".
- */
- var unwrapStrings = true
-
- def deprecation_=(x: Boolean) = {
- val old = intp.settings.deprecation.value
- intp.settings.deprecation.value = x
- if (!old && x) println("Enabled -deprecation output.")
- else if (old && !x) println("Disabled -deprecation output.")
- }
- def deprecation: Boolean = intp.settings.deprecation.value
-
- def allSettings = Map(
- "maxPrintString" -> maxPrintString,
- "maxAutoprintCompletion" -> maxAutoprintCompletion,
- "unwrapStrings" -> unwrapStrings,
- "deprecation" -> deprecation
- )
-
- private def allSettingsString =
- allSettings.toList sortBy (_._1) map { case (k, v) => " " + k + " = " + v + "\n" } mkString
-
- override def toString = """
- | SparkISettings {
- | %s
- | }""".stripMargin.format(allSettingsString)
-}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala
index 41a1731d60..64084209e8 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala
@@ -1,5 +1,5 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips
*/
@@ -12,30 +12,34 @@ import scala.collection.{ mutable, immutable }
trait SparkImports {
self: SparkIMain =>
-
+
import global._
import definitions.{ ScalaPackage, JavaLangPackage, PredefModule }
import memberHandlers._
-
+
+ def isNoImports = settings.noimports.value
+ def isNoPredef = settings.nopredef.value
+
/** Synthetic import handlers for the language defined imports. */
private def makeWildcardImportHandler(sym: Symbol): ImportHandler = {
val hd :: tl = sym.fullName.split('.').toList map newTermName
val tree = Import(
tl.foldLeft(Ident(hd): Tree)((x, y) => Select(x, y)),
- List(ImportSelector(nme.WILDCARD, -1, null, -1))
+ ImportSelector.wildList
)
tree setSymbol sym
new ImportHandler(tree)
}
-
+
/** Symbols whose contents are language-defined to be imported. */
def languageWildcardSyms: List[Symbol] = List(JavaLangPackage, ScalaPackage, PredefModule)
def languageWildcards: List[Type] = languageWildcardSyms map (_.tpe)
def languageWildcardHandlers = languageWildcardSyms map makeWildcardImportHandler
-
- def importedTerms = onlyTerms(importHandlers flatMap (_.importedNames))
- def importedTypes = onlyTypes(importHandlers flatMap (_.importedNames))
-
+
+ def allImportedNames = importHandlers flatMap (_.importedNames)
+ def importedTerms = onlyTerms(allImportedNames)
+ def importedTypes = onlyTypes(allImportedNames)
+
/** Types which have been wildcard imported, such as:
* val x = "abc" ; import x._ // type java.lang.String
* import java.lang.String._ // object java.lang.String
@@ -48,30 +52,28 @@ trait SparkImports {
* into the compiler scopes.
*/
def sessionWildcards: List[Type] = {
- importHandlers flatMap {
- case x if x.importsWildcard => x.targetType
- case _ => None
- } distinct
+ importHandlers filter (_.importsWildcard) map (_.targetType) distinct
}
def wildcardTypes = languageWildcards ++ sessionWildcards
-
+
def languageSymbols = languageWildcardSyms flatMap membersAtPickler
def sessionImportedSymbols = importHandlers flatMap (_.importedSymbols)
def importedSymbols = languageSymbols ++ sessionImportedSymbols
def importedTermSymbols = importedSymbols collect { case x: TermSymbol => x }
def importedTypeSymbols = importedSymbols collect { case x: TypeSymbol => x }
def implicitSymbols = importedSymbols filter (_.isImplicit)
-
- def importedTermNamed(name: String) = importedTermSymbols find (_.name.toString == name)
+
+ def importedTermNamed(name: String): Symbol =
+ importedTermSymbols find (_.name.toString == name) getOrElse NoSymbol
/** Tuples of (source, imported symbols) in the order they were imported.
*/
def importedSymbolsBySource: List[(Symbol, List[Symbol])] = {
val lang = languageWildcardSyms map (sym => (sym, membersAtPickler(sym)))
- val session = importHandlers filter (_.targetType.isDefined) map { mh =>
- (mh.targetType.get.typeSymbol, mh.importedSymbols)
+ val session = importHandlers filter (_.targetType != NoType) map { mh =>
+ (mh.targetType.typeSymbol, mh.importedSymbols)
}
-
+
lang ++ session
}
def implicitSymbolsBySource: List[(Symbol, List[Symbol])] = {
@@ -79,7 +81,7 @@ trait SparkImports {
case (k, vs) => (k, vs filter (_.isImplicit))
} filterNot (_._2.isEmpty)
}
-
+
/** Compute imports that allow definitions from previous
* requests to be visible in a new request. Returns
* three pieces of related code:
@@ -90,7 +92,7 @@ trait SparkImports {
* 2. A code fragment that should go after the code
* of the new request.
*
- * 3. An access path which can be traverested to access
+ * 3. An access path which can be traversed to access
* any bindings inside code wrapped by #1 and #2 .
*
* The argument is a set of Names that need to be imported.
@@ -103,27 +105,27 @@ trait SparkImports {
* (3) It imports multiple same-named implicits, but only the
* last one imported is actually usable.
*/
- case class ComputedImports(prepend: String, append: String, access: String)
- protected def importsCode(wanted: Set[Name]): ComputedImports = {
- /** Narrow down the list of requests from which imports
+ case class SparkComputedImports(prepend: String, append: String, access: String)
+
+ protected def importsCode(wanted: Set[Name]): SparkComputedImports = {
+ /** Narrow down the list of requests from which imports
* should be taken. Removes requests which cannot contribute
* useful imports for the specified set of wanted names.
*/
case class ReqAndHandler(req: Request, handler: MemberHandler) { }
-
- def reqsToUse: List[ReqAndHandler] = {
+
+ def reqsToUse: List[ReqAndHandler] = {
/** Loop through a list of MemberHandlers and select which ones to keep.
* 'wanted' is the set of names that need to be imported.
*/
def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
- val isWanted = wanted contains _
// Single symbol imports might be implicits! See bug #1752. Rather than
// try to finesse this, we will mimic all imports for now.
def keepHandler(handler: MemberHandler) = handler match {
case _: ImportHandler => true
- case x => x.definesImplicit || (x.definedNames exists isWanted)
+ case x => x.definesImplicit || (x.definedNames exists wanted)
}
-
+
reqs match {
case Nil => Nil
case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted)
@@ -133,7 +135,7 @@ trait SparkImports {
rh :: select(rest, newWanted)
}
}
-
+
/** Flatten the handlers out and pair each with the original request */
select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse
}
@@ -147,8 +149,13 @@ trait SparkImports {
code append "class %sC extends Serializable {\n".format(impname)
trailingBraces append "}\nval " + impname + " = new " + impname + "C;\n"
accessPath append ("." + impname)
-
+
currentImps.clear
+ // code append "object %s {\n".format(impname)
+ // trailingBraces append "}\n"
+ // accessPath append ("." + impname)
+
+ // currentImps.clear
}
addWrapper()
@@ -159,36 +166,33 @@ trait SparkImports {
// If the user entered an import, then just use it; add an import wrapping
// level if the import might conflict with some other import
case x: ImportHandler =>
- if (x.importsWildcard || (currentImps exists (x.importedNames contains _)))
+ if (x.importsWildcard || currentImps.exists(x.importedNames contains _))
addWrapper()
-
+
code append (x.member + "\n")
-
+
// give wildcard imports a import wrapper all to their own
if (x.importsWildcard) addWrapper()
else currentImps ++= x.importedNames
// For other requests, import each defined name.
// import them explicitly instead of with _, so that
- // ambiguity errors will not be generated. Also, quote
- // the name of the variable, so that we don't need to
- // handle quoting keywords separately.
+ // ambiguity errors will not be generated. Also, quote
+ // the name of the variable, so that we don't need to
+ // handle quoting keywords separately.
case x =>
for (imv <- x.definedNames) {
- // MATEI: Changed this check because it was messing up for case classes
- // (trying to import them twice within the same wrapper), but that is more likely
- // due to a miscomputation of names that makes the code think they're unique.
- // Need to evaluate whether having so many wrappers is a bad thing.
- /*if (currentImps contains imv)*/
- val imvName = imv.toString
- if (currentImps exists (_.toString == imvName)) addWrapper()
-
+ if (currentImps contains imv) addWrapper()
val objName = req.lineRep.readPath
val valName = "$VAL" + newValId();
- code.append("val " + valName + " = " + objName + ".INSTANCE;\n")
- code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n")
-
- //code append ("import %s\n" format (req fullPath imv))
+
+ if(!code.toString.endsWith(".`" + imv + "`;\n")) { // Which means already imported
+ code.append("val " + valName + " = " + objName + ".INSTANCE;\n")
+ code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n")
+ }
+ // code.append("val " + valName + " = " + objName + ".INSTANCE;\n")
+ // code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n")
+ // code append ("import " + (req fullPath imv) + "\n")
currentImps += imv
}
}
@@ -196,14 +200,14 @@ trait SparkImports {
// add one extra wrapper, to prevent warnings in the common case of
// redefining the value bound in the last interpreter request.
addWrapper()
- ComputedImports(code.toString, trailingBraces.toString, accessPath.toString)
+ SparkComputedImports(code.toString, trailingBraces.toString, accessPath.toString)
}
-
+
private def allReqAndHandlers =
prevRequestList flatMap (req => req.handlers map (req -> _))
private def membersAtPickler(sym: Symbol): List[Symbol] =
- atPickler(sym.info.nonPrivateMembers)
+ beforePickler(sym.info.nonPrivateMembers.toList)
private var curValId = 0
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
index fdc172d753..8865f82bc0 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
@@ -1,5 +1,5 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Paul Phillips
*/
@@ -11,30 +11,31 @@ import scala.tools.nsc.interpreter._
import scala.tools.jline._
import scala.tools.jline.console.completer._
import Completion._
-import collection.mutable.ListBuffer
+import scala.collection.mutable.ListBuffer
+import org.apache.spark.Logging
// REPL completor - queries supplied interpreter for valid
// completions based on current contents of buffer.
-class SparkJLineCompletion(val intp: SparkIMain) extends Completion with CompletionOutput {
+class SparkJLineCompletion(val intp: SparkIMain) extends Completion with CompletionOutput with Logging {
val global: intp.global.type = intp.global
import global._
- import definitions.{ PredefModule, RootClass, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage }
+ import definitions.{ PredefModule, AnyClass, AnyRefClass, ScalaPackage, JavaLangPackage }
+ import rootMirror.{ RootClass, getModuleIfDefined }
type ExecResult = Any
- import intp.{ DBG, debugging, afterTyper }
-
+ import intp.{ debugging }
+
// verbosity goes up with consecutive tabs
private var verbosity: Int = 0
def resetVerbosity() = verbosity = 0
-
- def getType(name: String, isModule: Boolean) = {
- val f = if (isModule) definitions.getModule(_: Name) else definitions.getClass(_: Name)
- try Some(f(name).tpe)
- catch { case _: MissingRequirementError => None }
- }
-
- def typeOf(name: String) = getType(name, false)
- def moduleOf(name: String) = getType(name, true)
-
+
+ def getSymbol(name: String, isModule: Boolean) = (
+ if (isModule) getModuleIfDefined(name)
+ else getModuleIfDefined(name)
+ )
+ def getType(name: String, isModule: Boolean) = getSymbol(name, isModule).tpe
+ def typeOf(name: String) = getType(name, false)
+ def moduleOf(name: String) = getType(name, true)
+
trait CompilerCompletion {
def tp: Type
def effectiveTp = tp match {
@@ -48,16 +49,16 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
private def anyMembers = AnyClass.tpe.nonPrivateMembers
def anyRefMethodsToShow = Set("isInstanceOf", "asInstanceOf", "toString")
- def tos(sym: Symbol) = sym.name.decode.toString
- def memberNamed(s: String) = members find (x => tos(x) == s)
- def hasMethod(s: String) = methods exists (x => tos(x) == s)
+ def tos(sym: Symbol): String = sym.decodedName
+ def memberNamed(s: String) = afterTyper(effectiveTp member newTermName(s))
+ def hasMethod(s: String) = memberNamed(s).isMethod
// XXX we'd like to say "filterNot (_.isDeprecated)" but this causes the
// compiler to crash for reasons not yet known.
- def members = afterTyper((effectiveTp.nonPrivateMembers ++ anyMembers) filter (_.isPublic))
- def methods = members filter (_.isMethod)
- def packages = members filter (_.isPackage)
- def aliases = members filter (_.isAliasType)
+ def members = afterTyper((effectiveTp.nonPrivateMembers.toList ++ anyMembers) filter (_.isPublic))
+ def methods = members.toList filter (_.isMethod)
+ def packages = members.toList filter (_.isPackage)
+ def aliases = members.toList filter (_.isAliasType)
def memberNames = members map tos
def methodNames = methods map tos
@@ -65,12 +66,19 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
def aliasNames = aliases map tos
}
+ object NoTypeCompletion extends TypeMemberCompletion(NoType) {
+ override def memberNamed(s: String) = NoSymbol
+ override def members = Nil
+ override def follow(s: String) = None
+ override def alternativesFor(id: String) = Nil
+ }
+
object TypeMemberCompletion {
def apply(tp: Type, runtimeType: Type, param: NamedParam): TypeMemberCompletion = {
new TypeMemberCompletion(tp) {
var upgraded = false
lazy val upgrade = {
- intp rebind param
+ intp rebind param
intp.reporter.printMessage("\nRebinding stable value %s from %s to %s".format(param.name, tp, param.tpe))
upgraded = true
new TypeMemberCompletion(runtimeType)
@@ -92,7 +100,8 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
}
}
def apply(tp: Type): TypeMemberCompletion = {
- if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp)
+ if (tp eq NoType) NoTypeCompletion
+ else if (tp.typeSymbol.isPackageClass) new PackageCompletion(tp)
else new TypeMemberCompletion(tp)
}
def imported(tp: Type) = new ImportCompletion(tp)
@@ -103,9 +112,9 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
def excludeEndsWith: List[String] = Nil
def excludeStartsWith: List[String] = List("<") // <byname>, <repeated>, etc.
def excludeNames: List[String] = (anyref.methodNames filterNot anyRefMethodsToShow) :+ "_root_"
-
+
def methodSignatureString(sym: Symbol) = {
- SparkIMain stripString afterTyper(new MethodSymbolOutput(sym).methodString())
+ IMain stripString afterTyper(new MethodSymbolOutput(sym).methodString())
}
def exclude(name: String): Boolean = (
@@ -118,10 +127,10 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
def completions(verbosity: Int) =
debugging(tp + " completions ==> ")(filtered(memberNames))
-
+
override def follow(s: String): Option[CompletionAware] =
- debugging(tp + " -> '" + s + "' ==> ")(memberNamed(s) map (x => TypeMemberCompletion(x.tpe)))
-
+ debugging(tp + " -> '" + s + "' ==> ")(Some(TypeMemberCompletion(memberNamed(s).tpe)) filterNot (_ eq NoTypeCompletion))
+
override def alternativesFor(id: String): List[String] =
debugging(id + " alternatives ==> ") {
val alts = members filter (x => x.isMethod && tos(x) == id) map methodSignatureString
@@ -131,7 +140,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
override def toString = "%s (%d members)".format(tp, members.size)
}
-
+
class PackageCompletion(tp: Type) extends TypeMemberCompletion(tp) {
override def excludeNames = anyref.methodNames
}
@@ -142,43 +151,44 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
case _ => memberNames
}
}
-
+
class ImportCompletion(tp: Type) extends TypeMemberCompletion(tp) {
override def completions(verbosity: Int) = verbosity match {
case 0 => filtered(members filterNot (_.isSetter) map tos)
case _ => super.completions(verbosity)
}
}
-
+
// not for completion but for excluding
object anyref extends TypeMemberCompletion(AnyRefClass.tpe) { }
-
+
// the unqualified vals/defs/etc visible in the repl
object ids extends CompletionAware {
override def completions(verbosity: Int) = intp.unqualifiedIds ++ List("classOf") //, "_root_")
// now we use the compiler for everything.
- override def follow(id: String) = {
- if (completions(0) contains id) {
- intp typeOfExpression id map { tpe =>
- def default = TypeMemberCompletion(tpe)
-
- // only rebinding vals in power mode for now.
- if (!isReplPower) default
- else intp runtimeClassAndTypeOfTerm id match {
- case Some((clazz, runtimeType)) =>
- val sym = intp.symbolOfTerm(id)
- if (sym.isStable) {
- val param = new NamedParam.Untyped(id, intp valueOfTerm id getOrElse null)
- TypeMemberCompletion(tpe, runtimeType, param)
- }
- else default
- case _ =>
- default
+ override def follow(id: String): Option[CompletionAware] = {
+ if (!completions(0).contains(id))
+ return None
+
+ val tpe = intp typeOfExpression id
+ if (tpe == NoType)
+ return None
+
+ def default = Some(TypeMemberCompletion(tpe))
+
+ // only rebinding vals in power mode for now.
+ if (!isReplPower) default
+ else intp runtimeClassAndTypeOfTerm id match {
+ case Some((clazz, runtimeType)) =>
+ val sym = intp.symbolOfTerm(id)
+ if (sym.isStable) {
+ val param = new NamedParam.Untyped(id, intp valueOfTerm id getOrElse null)
+ Some(TypeMemberCompletion(tpe, runtimeType, param))
}
- }
+ else default
+ case _ =>
+ default
}
- else
- None
}
override def toString = "<repl ids> (%s)".format(completions(0).size)
}
@@ -187,17 +197,10 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
private def imported = intp.sessionWildcards map TypeMemberCompletion.imported
// literal Ints, Strings, etc.
- object literals extends CompletionAware {
- def simpleParse(code: String): Tree = {
- val unit = new CompilationUnit(new util.BatchSourceFile("<console>", code))
- val scanner = new syntaxAnalyzer.UnitParser(unit)
- val tss = scanner.templateStatSeq(false)._2
-
- if (tss.size == 1) tss.head else EmptyTree
- }
-
+ object literals extends CompletionAware {
+ def simpleParse(code: String): Tree = newUnitParser(code).templateStats().last
def completions(verbosity: Int) = Nil
-
+
override def follow(id: String) = simpleParse(id) match {
case x: Literal => Some(new LiteralCompletion(x))
case _ => None
@@ -210,18 +213,18 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
override def follow(id: String) = id match {
case "_root_" => Some(this)
case _ => super.follow(id)
- }
+ }
}
// members of Predef
object predef extends TypeMemberCompletion(PredefModule.tpe) {
override def excludeEndsWith = super.excludeEndsWith ++ List("Wrapper", "ArrayOps")
override def excludeStartsWith = super.excludeStartsWith ++ List("wrap")
override def excludeNames = anyref.methodNames
-
+
override def exclude(name: String) = super.exclude(name) || (
(name contains "2")
)
-
+
override def completions(verbosity: Int) = verbosity match {
case 0 => Nil
case _ => super.completions(verbosity)
@@ -234,7 +237,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
override def exclude(name: String) = super.exclude(name) || (
skipArity(name)
)
-
+
override def completions(verbosity: Int) = verbosity match {
case 0 => filtered(packageNames ++ aliasNames)
case _ => super.completions(verbosity)
@@ -244,7 +247,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
object javalang extends PackageCompletion(JavaLangPackage.tpe) {
override lazy val excludeEndsWith = super.excludeEndsWith ++ List("Exception", "Error")
override lazy val excludeStartsWith = super.excludeStartsWith ++ List("CharacterData")
-
+
override def completions(verbosity: Int) = verbosity match {
case 0 => filtered(packageNames)
case _ => super.completions(verbosity)
@@ -256,7 +259,7 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
lazy val topLevelBase: List[CompletionAware] = List(ids, rootClass, predef, scalalang, javalang, literals)
def topLevel = topLevelBase ++ imported
def topLevelThreshold = 50
-
+
// the first tier of top level objects (doesn't include file completion)
def topLevelFor(parsed: Parsed): List[String] = {
val buf = new ListBuffer[String]
@@ -280,19 +283,6 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
if (parsed.isEmpty) xs map ("." + _) else xs
}
- // chasing down results which won't parse
- def execute(line: String): Option[ExecResult] = {
- val parsed = Parsed(line)
- def noDotOrSlash = line forall (ch => ch != '.' && ch != '/')
-
- if (noDotOrSlash) None // we defer all unqualified ids to the repl.
- else {
- (ids executionFor parsed) orElse
- (rootClass executionFor parsed) orElse
- (FileCompletion executionFor line)
- }
- }
-
// generic interface for querying (e.g. interpreter loop, testing)
def completions(buf: String): List[String] =
topLevelFor(Parsed.dotted(buf + ".", buf.length + 1))
@@ -327,11 +317,11 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
// This is jline's entry point for completion.
override def complete(buf: String, cursor: Int): Candidates = {
verbosity = if (isConsecutiveTabs(buf, cursor)) verbosity + 1 else 0
- DBG("\ncomplete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity))
+ logDebug("\ncomplete(%s, %d) last = (%s, %d), verbosity: %s".format(buf, cursor, lastBuf, lastCursor, verbosity))
// we don't try lower priority completions unless higher ones return no results.
def tryCompletion(p: Parsed, completionFunction: Parsed => List[String]): Option[Candidates] = {
- val winners = completionFunction(p)
+ val winners = completionFunction(p)
if (winners.isEmpty)
return None
val newCursor =
@@ -340,39 +330,45 @@ class SparkJLineCompletion(val intp: SparkIMain) extends Completion with Complet
val advance = commonPrefix(winners)
lastCursor = p.position + advance.length
lastBuf = (buf take p.position) + advance
- DBG("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(
+ logDebug("tryCompletion(%s, _) lastBuf = %s, lastCursor = %s, p.position = %s".format(
p, lastBuf, lastCursor, p.position))
p.position
}
-
+
Some(Candidates(newCursor, winners))
}
-
+
def mkDotted = Parsed.dotted(buf, cursor) withVerbosity verbosity
def mkUndelimited = Parsed.undelimited(buf, cursor) withVerbosity verbosity
// a single dot is special cased to completion on the previous result
def lastResultCompletion =
- if (!looksLikeInvocation(buf)) None
+ if (!looksLikeInvocation(buf)) None
else tryCompletion(Parsed.dotted(buf drop 1, cursor), lastResultFor)
- def regularCompletion = tryCompletion(mkDotted, topLevelFor)
- def fileCompletion =
- if (!looksLikePath(buf)) None
- else tryCompletion(mkUndelimited, FileCompletion completionsFor _.buffer)
-
- /** This is the kickoff point for all manner of theoretically possible compiler
- * unhappiness - fault may be here or elsewhere, but we don't want to crash the
- * repl regardless. Hopefully catching Exception is enough, but because the
- * compiler still throws some Errors it may not be.
+ def tryAll = (
+ lastResultCompletion
+ orElse tryCompletion(mkDotted, topLevelFor)
+ getOrElse Candidates(cursor, Nil)
+ )
+
+ /**
+ * This is the kickoff point for all manner of theoretically
+ * possible compiler unhappiness. The fault may be here or
+ * elsewhere, but we don't want to crash the repl regardless.
+ * The compiler makes it impossible to avoid catching Throwable
+ * with its unfortunate tendency to throw java.lang.Errors and
+ * AssertionErrors as the hats drop. We take two swings at it
+ * because there are some spots which like to throw an assertion
+ * once, then work after that. Yeah, what can I say.
*/
- try {
- (lastResultCompletion orElse regularCompletion orElse fileCompletion) getOrElse Candidates(cursor, Nil)
- }
- catch {
- case ex: Exception =>
- DBG("Error: complete(%s, %s) provoked %s".format(buf, cursor, ex))
- Candidates(cursor, List(" ", "<completion error: " + ex.getMessage + ">"))
+ try tryAll
+ catch { case ex: Throwable =>
+ logWarning("Error: complete(%s, %s) provoked".format(buf, cursor) + ex)
+ Candidates(cursor,
+ if (isReplDebug) List("<error:" + ex + ">")
+ else Nil
+ )
}
}
}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
index d9e1de105c..60a4d7841e 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
@@ -1,5 +1,5 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Stepan Koltsov
*/
@@ -15,28 +15,33 @@ import scala.collection.JavaConverters._
import Completion._
import io.Streamable.slurp
-/** Reads from the console using JLine */
-class SparkJLineReader(val completion: Completion) extends InteractiveReader {
+/**
+ * Reads from the console using JLine.
+ */
+class SparkJLineReader(_completion: => Completion) extends InteractiveReader {
val interactive = true
+ val consoleReader = new JLineConsoleReader()
+
+ lazy val completion = _completion
lazy val history: JLineHistory = JLineHistory()
- lazy val keyBindings =
- try KeyBinding parse slurp(term.getDefaultBindings)
- catch { case _: Exception => Nil }
private def term = consoleReader.getTerminal()
def reset() = term.reset()
def init() = term.init()
-
+
def scalaToJline(tc: ScalaCompleter): Completer = new Completer {
def complete(_buf: String, cursor: Int, candidates: JList[CharSequence]): Int = {
- val buf = if (_buf == null) "" else _buf
+ val buf = if (_buf == null) "" else _buf
val Candidates(newCursor, newCandidates) = tc.complete(buf, cursor)
newCandidates foreach (candidates add _)
newCursor
}
}
-
+
class JLineConsoleReader extends ConsoleReader with ConsoleReaderHelper {
+ if ((history: History) ne NoHistory)
+ this setHistory history
+
// working around protected/trait/java insufficiencies.
def goBack(num: Int): Unit = back(num)
def readOneKey(prompt: String) = {
@@ -46,34 +51,28 @@ class SparkJLineReader(val completion: Completion) extends InteractiveReader {
}
def eraseLine() = consoleReader.resetPromptLine("", "", 0)
def redrawLineAndFlush(): Unit = { flush() ; drawLine() ; flush() }
-
- this setBellEnabled false
- if (history ne NoHistory)
- this setHistory history
-
- if (completion ne NoCompletion) {
- val argCompletor: ArgumentCompleter =
- new ArgumentCompleter(new JLineDelimiter, scalaToJline(completion.completer()))
- argCompletor setStrict false
-
- this addCompleter argCompletor
- this setAutoprintThreshold 400 // max completion candidates without warning
+ // override def readLine(prompt: String): String
+
+ // A hook for running code after the repl is done initializing.
+ lazy val postInit: Unit = {
+ this setBellEnabled false
+
+ if (completion ne NoCompletion) {
+ val argCompletor: ArgumentCompleter =
+ new ArgumentCompleter(new JLineDelimiter, scalaToJline(completion.completer()))
+ argCompletor setStrict false
+
+ this addCompleter argCompletor
+ this setAutoprintThreshold 400 // max completion candidates without warning
+ }
}
}
-
- val consoleReader: JLineConsoleReader = new JLineConsoleReader()
- def currentLine: String = consoleReader.getCursorBuffer.buffer.toString
+ def currentLine = consoleReader.getCursorBuffer.buffer.toString
def redrawLine() = consoleReader.redrawLineAndFlush()
- def eraseLine() = {
- while (consoleReader.delete()) { }
- // consoleReader.eraseLine()
- }
+ def eraseLine() = consoleReader.eraseLine()
+ // Alternate implementation, not sure if/when I need this.
+ // def eraseLine() = while (consoleReader.delete()) { }
def readOneLine(prompt: String) = consoleReader readLine prompt
def readOneKey(prompt: String) = consoleReader readOneKey prompt
}
-
-object SparkJLineReader {
- def apply(intp: SparkIMain): SparkJLineReader = apply(new SparkJLineCompletion(intp))
- def apply(comp: Completion): SparkJLineReader = new SparkJLineReader(comp)
-}
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
index a3409bf665..382f8360a7 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
@@ -1,5 +1,5 @@
/* NSC -- new Scala compiler
- * Copyright 2005-2011 LAMP/EPFL
+ * Copyright 2005-2013 LAMP/EPFL
* @author Martin Odersky
*/
@@ -10,13 +10,14 @@ import scala.tools.nsc.interpreter._
import scala.collection.{ mutable, immutable }
import scala.PartialFunction.cond
-import scala.reflect.NameTransformer
-import util.Chars
+import scala.reflect.internal.Chars
+import scala.reflect.internal.Flags._
+import scala.language.implicitConversions
trait SparkMemberHandlers {
val intp: SparkIMain
- import intp.{ Request, global, naming, atPickler }
+ import intp.{ Request, global, naming }
import global._
import naming._
@@ -29,7 +30,7 @@ trait SparkMemberHandlers {
front + (xs map string2codeQuoted mkString " + ")
}
private implicit def name2string(name: Name) = name.toString
-
+
/** A traverser that finds all mentioned identifiers, i.e. things
* that need to be imported. It might return extra names.
*/
@@ -54,26 +55,28 @@ trait SparkMemberHandlers {
}
def chooseHandler(member: Tree): MemberHandler = member match {
- case member: DefDef => new DefHandler(member)
- case member: ValDef => new ValHandler(member)
- case member@Assign(Ident(_), _) => new AssignHandler(member)
- case member: ModuleDef => new ModuleHandler(member)
- case member: ClassDef => new ClassHandler(member)
- case member: TypeDef => new TypeAliasHandler(member)
- case member: Import => new ImportHandler(member)
- case DocDef(_, documented) => chooseHandler(documented)
- case member => new GenericHandler(member)
+ case member: DefDef => new DefHandler(member)
+ case member: ValDef => new ValHandler(member)
+ case member: Assign => new AssignHandler(member)
+ case member: ModuleDef => new ModuleHandler(member)
+ case member: ClassDef => new ClassHandler(member)
+ case member: TypeDef => new TypeAliasHandler(member)
+ case member: Import => new ImportHandler(member)
+ case DocDef(_, documented) => chooseHandler(documented)
+ case member => new GenericHandler(member)
}
-
+
sealed abstract class MemberDefHandler(override val member: MemberDef) extends MemberHandler(member) {
+ def symbol = if (member.symbol eq null) NoSymbol else member.symbol
def name: Name = member.name
def mods: Modifiers = member.mods
def keyword = member.keyword
- def prettyName = NameTransformer.decode(name)
-
+ def prettyName = name.decode
+
override def definesImplicit = member.mods.isImplicit
override def definesTerm: Option[TermName] = Some(name.toTermName) filter (_ => name.isTermName)
override def definesType: Option[TypeName] = Some(name.toTypeName) filter (_ => name.isTypeName)
+ override def definedSymbols = if (symbol eq NoSymbol) Nil else List(symbol)
}
/** Class to handle one member among all the members included
@@ -82,11 +85,8 @@ trait SparkMemberHandlers {
sealed abstract class MemberHandler(val member: Tree) {
def definesImplicit = false
def definesValue = false
- def isLegalTopLevel = member match {
- case _: ModuleDef | _: ClassDef | _: Import => true
- case _ => false
- }
-
+ def isLegalTopLevel = false
+
def definesTerm = Option.empty[TermName]
def definesType = Option.empty[TypeName]
@@ -94,6 +94,7 @@ trait SparkMemberHandlers {
def importedNames = List[Name]()
def definedNames = definesTerm.toList ++ definesType.toList
def definedOrImported = definedNames ++ importedNames
+ def definedSymbols = List[Symbol]()
def extraCodeToEvaluate(req: Request): String = ""
def resultExtractionCode(req: Request): String = ""
@@ -103,11 +104,11 @@ trait SparkMemberHandlers {
}
class GenericHandler(member: Tree) extends MemberHandler(member)
-
+
class ValHandler(member: ValDef) extends MemberDefHandler(member) {
- val maxStringElements = 1000 // no need to mkString billions of elements
+ val maxStringElements = 1000 // no need to mkString billions of elements
override def definesValue = true
-
+
override def resultExtractionCode(req: Request): String = {
val isInternal = isUserVarName(name) && req.lookupTypeOf(name) == "Unit"
if (!mods.isPublic || isInternal) ""
@@ -116,22 +117,27 @@ trait SparkMemberHandlers {
val resultString =
if (mods.isLazy) codegenln(false, "<lazy>")
else any2stringOf(req fullPath name, maxStringElements)
-
- """ + "%s: %s = " + %s""".format(prettyName, string2code(req typeOf name), resultString)
+
+ val vidString =
+ if (replProps.vids) """" + " @ " + "%%8x".format(System.identityHashCode(%s)) + " """.trim.format(req fullPath name)
+ else ""
+
+ """ + "%s%s: %s = " + %s""".format(string2code(prettyName), vidString, string2code(req typeOf name), resultString)
}
}
}
class DefHandler(member: DefDef) extends MemberDefHandler(member) {
private def vparamss = member.vparamss
- // true if 0-arity
- override def definesValue = vparamss.isEmpty || vparamss.head.isEmpty
+ private def isMacro = member.symbol hasFlag MACRO
+ // true if not a macro and 0-arity
+ override def definesValue = !isMacro && flattensToEmpty(vparamss)
override def resultExtractionCode(req: Request) =
if (mods.isPublic) codegenln(name, ": ", req.typeOf(name)) else ""
}
class AssignHandler(member: Assign) extends MemberHandler(member) {
- val lhs = member.lhs.asInstanceOf[Ident] // an unfortunate limitation
+ val Assign(lhs, rhs) = member
val name = newTermName(freshInternalVarName())
override def definesTerm = Some(name)
@@ -142,15 +148,15 @@ trait SparkMemberHandlers {
/** Print out lhs instead of the generated varName */
override def resultExtractionCode(req: Request) = {
val lhsType = string2code(req lookupTypeOf name)
- val res = string2code(req fullPath name)
-
- """ + "%s: %s = " + %s + "\n" """.format(lhs, lhsType, res) + "\n"
+ val res = string2code(req fullPath name)
+ """ + "%s: %s = " + %s + "\n" """.format(string2code(lhs.toString), lhsType, res) + "\n"
}
}
class ModuleHandler(module: ModuleDef) extends MemberDefHandler(module) {
override def definesTerm = Some(name)
override def definesValue = true
+ override def isLegalTopLevel = true
override def resultExtractionCode(req: Request) = codegenln("defined module ", name)
}
@@ -158,7 +164,8 @@ trait SparkMemberHandlers {
class ClassHandler(member: ClassDef) extends MemberDefHandler(member) {
override def definesType = Some(name.toTypeName)
override def definesTerm = Some(name.toTermName) filter (_ => mods.isCase)
-
+ override def isLegalTopLevel = true
+
override def resultExtractionCode(req: Request) =
codegenln("defined %s %s".format(keyword, name))
}
@@ -173,26 +180,42 @@ trait SparkMemberHandlers {
class ImportHandler(imp: Import) extends MemberHandler(imp) {
val Import(expr, selectors) = imp
- def targetType = intp.typeOfExpression("" + expr)
-
+ def targetType: Type = intp.typeOfExpression("" + expr)
+ override def isLegalTopLevel = true
+
+ def createImportForName(name: Name): String = {
+ selectors foreach {
+ case sel @ ImportSelector(old, _, `name`, _) => return "import %s.{ %s }".format(expr, sel)
+ case _ => ()
+ }
+ "import %s.%s".format(expr, name)
+ }
+ // TODO: Need to track these specially to honor Predef masking attempts,
+ // because they must be the leading imports in the code generated for each
+ // line. We can use the same machinery as Contexts now, anyway.
+ def isPredefImport = isReferenceToPredef(expr)
+
// wildcard imports, e.g. import foo._
private def selectorWild = selectors filter (_.name == nme.USCOREkw)
// renamed imports, e.g. import foo.{ bar => baz }
private def selectorRenames = selectors map (_.rename) filterNot (_ == null)
-
+
/** Whether this import includes a wildcard import */
val importsWildcard = selectorWild.nonEmpty
-
+
+ /** Whether anything imported is implicit .*/
+ def importsImplicit = implicitSymbols.nonEmpty
+
def implicitSymbols = importedSymbols filter (_.isImplicit)
def importedSymbols = individualSymbols ++ wildcardSymbols
-
+
lazy val individualSymbols: List[Symbol] =
- atPickler(targetType.toList flatMap (tp => individualNames map (tp nonPrivateMember _)))
+ beforePickler(individualNames map (targetType nonPrivateMember _))
lazy val wildcardSymbols: List[Symbol] =
- if (importsWildcard) atPickler(targetType.toList flatMap (_.nonPrivateMembers))
+ if (importsWildcard) beforePickler(targetType.nonPrivateMembers.toList)
else Nil
-
+
/** Complete list of names imported by a wildcard */
lazy val wildcardNames: List[Name] = wildcardSymbols map (_.name)
lazy val individualNames: List[Name] = selectorRenames filterNot (_ == nme.USCOREkw) flatMap (_.bothNames)
@@ -200,7 +223,7 @@ trait SparkMemberHandlers {
/** The names imported by this statement */
override lazy val importedNames: List[Name] = wildcardNames ++ individualNames
lazy val importsSymbolNamed: Set[String] = importedNames map (_.toString) toSet
-
+
def importString = imp.toString
override def resultExtractionCode(req: Request) = codegenln(importString) + "\n"
}
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 8f9b632c0e..daaa2a0305 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -1,32 +1,17 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
package org.apache.spark.repl
import java.io._
import java.net.URLClassLoader
import scala.collection.mutable.ArrayBuffer
-import scala.collection.JavaConversions._
-import org.scalatest.FunSuite
import com.google.common.io.Files
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext
+
class ReplSuite extends FunSuite {
+
def runInterpreter(master: String, input: String): String = {
val in = new BufferedReader(new StringReader(input + "\n"))
val out = new StringWriter()
@@ -56,104 +41,140 @@ class ReplSuite extends FunSuite {
def assertContains(message: String, output: String) {
assert(output.contains(message),
- "Interpreter output did not contain '" + message + "':\n" + output)
+ "Interpreter output did not contain '" + message + "':\n" + output)
}
def assertDoesNotContain(message: String, output: String) {
assert(!output.contains(message),
- "Interpreter output contained '" + message + "':\n" + output)
+ "Interpreter output contained '" + message + "':\n" + output)
+ }
+
+ test("propagation of local properties") {
+ // A mock ILoop that doesn't install the SIGINT handler.
+ class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) {
+ settings = new scala.tools.nsc.Settings
+ settings.usejavacp.value = true
+ org.apache.spark.repl.Main.interp = this
+ override def createInterpreter() {
+ intp = new SparkILoopInterpreter
+ intp.setContextClassLoader()
+ }
+ }
+
+ val out = new StringWriter()
+ val interp = new ILoop(new PrintWriter(out))
+ interp.sparkContext = new SparkContext("local", "repl-test")
+ interp.createInterpreter()
+ interp.intp.initialize()
+ interp.sparkContext.setLocalProperty("someKey", "someValue")
+
+ // Make sure the value we set in the caller to interpret is propagated in the thread that
+ // interprets the command.
+ interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")")
+ assert(out.toString.contains("someValue"))
+
+ interp.sparkContext.stop()
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
}
- test ("simple foreach with accumulator") {
- val output = runInterpreter("local", """
- val accum = sc.accumulator(0)
- sc.parallelize(1 to 10).foreach(x => accum += x)
- accum.value
- """)
+ test("simple foreach with accumulator") {
+ val output = runInterpreter("local",
+ """
+ |val accum = sc.accumulator(0)
+ |sc.parallelize(1 to 10).foreach(x => accum += x)
+ |accum.value
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res1: Int = 55", output)
}
- test ("external vars") {
- val output = runInterpreter("local", """
- var v = 7
- sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
- """)
+ test("external vars") {
+ val output = runInterpreter("local",
+ """
+ |var v = 7
+ |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
- test ("external classes") {
- val output = runInterpreter("local", """
- class C {
- def foo = 5
- }
- sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
- """)
+ test("external classes") {
+ val output = runInterpreter("local",
+ """
+ |class C {
+ |def foo = 5
+ |}
+ |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 50", output)
}
- test ("external functions") {
- val output = runInterpreter("local", """
- def double(x: Int) = x + x
- sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
- """)
+ test("external functions") {
+ val output = runInterpreter("local",
+ """
+ |def double(x: Int) = x + x
+ |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 110", output)
}
- test ("external functions that access vars") {
- val output = runInterpreter("local", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- """)
+ test("external functions that access vars") {
+ val output = runInterpreter("local",
+ """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
- test ("broadcast vars") {
+ test("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
// TODO: This doesn't actually work for arrays when we run in local mode!
- val output = runInterpreter("local", """
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
+ val output = runInterpreter("local",
+ """
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
}
- test ("interacting with files") {
+ test("interacting with files") {
val tempDir = Files.createTempDir()
val out = new FileWriter(tempDir + "/input")
out.write("Hello world!\n")
out.write("What's up?\n")
out.write("Goodbye\n")
out.close()
- val output = runInterpreter("local", """
- var file = sc.textFile("%s/input").cache()
- file.count()
- file.count()
- file.count()
- """.format(tempDir.getAbsolutePath))
+ val output = runInterpreter("local",
+ """
+ |var file = sc.textFile("%s/input").cache()
+ |file.count()
+ |file.count()
+ |file.count()
+ """.stripMargin.format(tempDir.getAbsolutePath))
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Long = 3", output)
@@ -161,19 +182,20 @@ class ReplSuite extends FunSuite {
assertContains("res2: Long = 3", output)
}
- test ("local-cluster mode") {
- val output = runInterpreter("local-cluster[1,1,512]", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
+ test("local-cluster mode") {
+ val output = runInterpreter("local-cluster[1,1,512]",
+ """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
@@ -183,19 +205,20 @@ class ReplSuite extends FunSuite {
}
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
- test ("running on Mesos") {
- val output = runInterpreter("localquiet", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
+ test("running on Mesos") {
+ val output = runInterpreter("localquiet",
+ """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """.stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
diff --git a/run-example b/run-example
index 08ec717ca5..a78192d31d 100755
--- a/run-example
+++ b/run-example
@@ -17,7 +17,12 @@
# limitations under the License.
#
-SCALA_VERSION=2.9.3
+cygwin=false
+case "`uname`" in
+ CYGWIN*) cygwin=true;;
+esac
+
+SCALA_VERSION=2.10
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"
@@ -59,6 +64,11 @@ fi
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
CLASSPATH="$SPARK_EXAMPLES_JAR:$CLASSPATH"
+if $cygwin; then
+ CLASSPATH=`cygpath -wp $CLASSPATH`
+ export SPARK_EXAMPLES_JAR=`cygpath -w $SPARK_EXAMPLES_JAR`
+fi
+
# Find java binary
if [ -n "${JAVA_HOME}" ]; then
RUNNER="${JAVA_HOME}/bin/java"
diff --git a/run-example2.cmd b/run-example2.cmd
index dbb371ecfc..d4ad98d6e7 100644
--- a/run-example2.cmd
+++ b/run-example2.cmd
@@ -17,7 +17,7 @@ rem See the License for the specific language governing permissions and
rem limitations under the License.
rem
-set SCALA_VERSION=2.9.3
+set SCALA_VERSION=2.10
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0
diff --git a/sbt/sbt b/sbt/sbt
index c31a0280ff..5942280585 100755
--- a/sbt/sbt
+++ b/sbt/sbt
@@ -17,12 +17,27 @@
# limitations under the License.
#
-EXTRA_ARGS=""
+cygwin=false
+case "`uname`" in
+ CYGWIN*) cygwin=true;;
+esac
+
+EXTRA_ARGS="-Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m"
if [ "$MESOS_HOME" != "" ]; then
- EXTRA_ARGS="-Djava.library.path=$MESOS_HOME/lib/java"
+ EXTRA_ARGS="$EXTRA_ARGS -Djava.library.path=$MESOS_HOME/lib/java"
fi
export SPARK_HOME=$(cd "$(dirname $0)/.." 2>&1 >/dev/null ; pwd)
export SPARK_TESTING=1 # To put test classes on classpath
-java -Xmx1200m -XX:MaxPermSize=350m -XX:ReservedCodeCacheSize=256m $EXTRA_ARGS $SBT_OPTS -jar "$SPARK_HOME"/sbt/sbt-launch-*.jar "$@"
+SBT_JAR="$SPARK_HOME"/sbt/sbt-launch-*.jar
+if $cygwin; then
+ SBT_JAR=`cygpath -w $SBT_JAR`
+ export SPARK_HOME=`cygpath -w $SPARK_HOME`
+ EXTRA_ARGS="$EXTRA_ARGS -Djline.terminal=jline.UnixTerminal -Dsbt.cygwin=true"
+ stty -icanon min 1 -echo > /dev/null 2>&1
+ java $EXTRA_ARGS $SBT_OPTS -jar $SBT_JAR "$@"
+ stty icanon echo > /dev/null 2>&1
+else
+ java $EXTRA_ARGS $SBT_OPTS -jar $SBT_JAR "$@"
+fi \ No newline at end of file
diff --git a/spark-class b/spark-class
index e111ef6da7..4eb95a9ba2 100755
--- a/spark-class
+++ b/spark-class
@@ -17,7 +17,12 @@
# limitations under the License.
#
-SCALA_VERSION=2.9.3
+cygwin=false
+case "`uname`" in
+ CYGWIN*) cygwin=true;;
+esac
+
+SCALA_VERSION=2.10
# Figure out where the Scala framework is installed
FWDIR="$(cd `dirname $0`; pwd)"
@@ -55,7 +60,7 @@ case "$1" in
'org.apache.spark.deploy.worker.Worker')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
;;
- 'org.apache.spark.executor.StandaloneExecutorBackend')
+ 'org.apache.spark.executor.CoarseGrainedExecutorBackend')
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'org.apache.spark.executor.MesosExecutorBackend')
@@ -95,16 +100,41 @@ export JAVA_OPTS
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
- ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
- if [[ $? != 0 ]]; then
- echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2
- echo "You need to build Spark with sbt/sbt assembly before running this program" >&2
+ num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l)
+ jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar")
+ if [ "$num_jars" -eq "0" ]; then
+ echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2
+ echo "You need to build Spark with 'sbt/sbt assembly' before running this program." >&2
+ exit 1
+ fi
+ if [ "$num_jars" -gt "1" ]; then
+ echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2
+ echo "$jars_list"
+ echo "Please remove all but one jar."
exit 1
fi
fi
+TOOLS_DIR="$FWDIR"/tools
+SPARK_TOOLS_JAR=""
+if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then
+ # Use the JAR from the SBT build
+ export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar`
+fi
+if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then
+ # Use the JAR from the Maven build
+ # TODO: this also needs to become an assembly!
+ export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`
+fi
+
# Compute classpath using external script
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
+CLASSPATH="$SPARK_TOOLS_JAR:$CLASSPATH"
+
+if $cygwin; then
+ CLASSPATH=`cygpath -wp $CLASSPATH`
+ export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR`
+fi
export CLASSPATH
if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then
@@ -115,3 +145,5 @@ if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then
fi
exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@"
+
+
diff --git a/spark-class2.cmd b/spark-class2.cmd
index d4d853e8ad..3869d0761b 100644
--- a/spark-class2.cmd
+++ b/spark-class2.cmd
@@ -65,10 +65,17 @@ if "%FOUND_JAR%"=="0" (
)
:skip_build_test
+set TOOLS_DIR=%FWDIR%tools
+set SPARK_TOOLS_JAR=
+for %%d in ("%TOOLS_DIR%\target\scala-%SCALA_VERSION%\spark-tools*assembly*.jar") do (
+ set SPARK_TOOLS_JAR=%%d
+)
+
rem Compute classpath using external script
set DONT_PRINT_CLASSPATH=1
call "%FWDIR%bin\compute-classpath.cmd"
set DONT_PRINT_CLASSPATH=0
+set CLASSPATH=%SPARK_TOOLS_JAR%;%CLASSPATH%
rem Figure out where java is.
set RUNNER=java
diff --git a/spark-shell b/spark-shell
index 9608bd3f30..d20af0fb39 100755
--- a/spark-shell
+++ b/spark-shell
@@ -23,7 +23,11 @@
# if those two env vars are set in spark-env.sh but MASTER is not.
# Options:
# -c <cores> Set the number of cores for REPL to use
-#
+
+cygwin=false
+case "`uname`" in
+ CYGWIN*) cygwin=true;;
+esac
# Enter posix mode for bash
set -o posix
@@ -79,7 +83,18 @@ if [[ ! $? ]]; then
saved_stty=""
fi
-$FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@"
+if $cygwin; then
+ # Workaround for issue involving JLine and Cygwin
+ # (see http://sourceforge.net/p/jline/bugs/40/).
+ # If you're using the Mintty terminal emulator in Cygwin, may need to set the
+ # "Backspace sends ^H" setting in "Keys" section of the Mintty options
+ # (see https://github.com/sbt/sbt/issues/562).
+ stty -icanon min 1 -echo > /dev/null 2>&1
+ $FWDIR/spark-class -Djline.terminal=unix $OPTIONS org.apache.spark.repl.Main "$@"
+ stty icanon echo > /dev/null 2>&1
+else
+ $FWDIR/spark-class $OPTIONS org.apache.spark.repl.Main "$@"
+fi
# record the exit status lest it be overwritten:
# then reenable echo and propagate the code.
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
deleted file mode 100644
index 65f79925a4..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar
+++ /dev/null
Binary files differ
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
deleted file mode 100644
index 29f45f4adb..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.md5
+++ /dev/null
@@ -1 +0,0 @@
-18876b8bc2e4cef28b6d191aa49d963f \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
deleted file mode 100644
index e3bd62bac0..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.jar.sha1
+++ /dev/null
@@ -1 +0,0 @@
-06b27270ffa52250a2c08703b397c99127b72060 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
deleted file mode 100644
index 082d35726a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom
+++ /dev/null
@@ -1,9 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
- <modelVersion>4.0.0</modelVersion>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version>
- <description>POM was created from install:install-file</description>
-</project>
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
deleted file mode 100644
index 92c4132b5b..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.md5
+++ /dev/null
@@ -1 +0,0 @@
-7bc4322266e6032bdf9ef6eebdd8097d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1 b/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
deleted file mode 100644
index 8a1d8a097a..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/0.7.2-spark/kafka-0.7.2-spark.pom.sha1
+++ /dev/null
@@ -1 +0,0 @@
-d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
deleted file mode 100644
index 720cd51c2f..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<metadata>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <versioning>
- <release>0.7.2-spark</release>
- <versions>
- <version>0.7.2-spark</version>
- </versions>
- <lastUpdated>20130121015225</lastUpdated>
- </versioning>
-</metadata>
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
deleted file mode 100644
index a4ce5dc9e8..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.md5
+++ /dev/null
@@ -1 +0,0 @@
-e2b9c7c5f6370dd1d21a0aae5e8dcd77 \ No newline at end of file
diff --git a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1 b/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
deleted file mode 100644
index b869eaf2a6..0000000000
--- a/streaming/lib/org/apache/kafka/kafka/maven-metadata-local.xml.sha1
+++ /dev/null
@@ -1 +0,0 @@
-2a4341da936b6c07a09383d17ffb185ac558ee91 \ No newline at end of file
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 3b25fb49fb..e3b6fee9b2 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -26,23 +26,29 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming_2.9.3</artifactId>
+ <artifactId>spark-streaming_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Streaming</name>
<url>http://spark.incubator.apache.org/</url>
<repositories>
- <!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
- <id>lib</id>
- <url>file://${project.basedir}/lib</url>
+ <id>apache-repo</id>
+ <name>Apache Repository</name>
+ <url>https://repository.apache.org/content/repositories/releases</url>
+ <releases>
+ <enabled>true</enabled>
+ </releases>
+ <snapshots>
+ <enabled>false</enabled>
+ </snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -55,10 +61,23 @@
<version>1.9.11</version>
</dependency>
<dependency>
- <groupId>org.apache.kafka</groupId>
- <artifactId>kafka</artifactId>
- <version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
- <scope>provided</scope>
+ <groupId>com.sksamuel.kafka</groupId>
+ <artifactId>kafka_${scala.binary.version}</artifactId>
+ <version>0.8.0-beta1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.sun.jmx</groupId>
+ <artifactId>jmxri</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.sun.jdmk</groupId>
+ <artifactId>jmxtools</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>net.sf.jopt-simple</groupId>
+ <artifactId>jopt-simple</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.apache.flume</groupId>
@@ -69,35 +88,39 @@
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
- <groupId>com.github.sgroschupf</groupId>
- <artifactId>zkclient</artifactId>
- <version>0.1</version>
- </dependency>
- <dependency>
<groupId>org.twitter4j</groupId>
<artifactId>twitter4j-stream</artifactId>
<version>3.0.3</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-zeromq</artifactId>
- <version>2.0.3</version>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-zeromq_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@@ -110,10 +133,19 @@
<artifactId>slf4j-log4j12</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.eclipse.paho</groupId>
+ <artifactId>mqtt-client</artifactId>
+ <version>0.4.0</version>
+ </dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.scalatest</groupId>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index bb9febad38..7b343d2376 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
- val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
+ val pendingTimes = ssc.scheduler.getPendingTimes()
val delaySeconds = MetadataCleaner.getDelaySeconds
def validate() {
@@ -94,7 +94,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
fs.delete(file, false)
fs.rename(writeFile, file)
- val finishTime = System.currentTimeMillis();
+ val finishTime = System.currentTimeMillis()
logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
"', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
return
@@ -124,7 +124,9 @@ class CheckpointWriter(checkpointDir: String) extends Logging {
def stop() {
synchronized {
- if (stopped) return ;
+ if (stopped) {
+ return
+ }
stopped = true
}
executor.shutdown()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 80da6bd30b..a78d3965ee 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -17,27 +17,23 @@
package org.apache.spark.streaming
-import org.apache.spark.streaming.dstream._
import StreamingContext._
-import org.apache.spark.util.MetadataCleaner
-
-//import Time._
-
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.scheduler.Job
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.MetadataCleaner
-import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.reflect.ClassTag
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.conf.Configuration
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
- * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]]
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.rdd.RDD]]
* for more details on RDDs). DStreams can either be created from live data (such as, data from
* HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations
* such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
@@ -56,7 +52,7 @@ import org.apache.hadoop.conf.Configuration
* - A function that is used to generate an RDD after each time interval
*/
-abstract class DStream[T: ClassManifest] (
+abstract class DStream[T: ClassTag] (
@transient protected[streaming] var ssc: StreamingContext
) extends Serializable with Logging {
@@ -82,7 +78,7 @@ abstract class DStream[T: ClassManifest] (
// RDDs generated, marked as protected[streaming] so that testsuites can access it
@transient
protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
-
+
// Time zero for the DStream
protected[streaming] var zeroTime: Time = null
@@ -274,16 +270,16 @@ abstract class DStream[T: ClassManifest] (
/**
* Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal
* method that should not be called directly.
- */
+ */
protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {
-
- // If an RDD was already generated and is being reused, then
+
+ // If an RDD was already generated and is being reused, then
// probably all RDDs in this DStream will be reused and hence should be cached
case Some(oldRDD) => Some(oldRDD)
-
+
// if RDD was not generated, and if the time is valid
// (based on sliding time of this DStream), then generate the RDD
case None => {
@@ -300,7 +296,7 @@ abstract class DStream[T: ClassManifest] (
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
- case None =>
+ case None =>
None
}
} else {
@@ -344,7 +340,7 @@ abstract class DStream[T: ClassManifest] (
dependencies.foreach(_.clearOldMetadata(time))
}
- /* Adds metadata to the Stream while it is running.
+ /* Adds metadata to the Stream while it is running.
* This methd should be overwritten by sublcasses of InputDStream.
*/
protected[streaming] def addMetadata(metadata: Any) {
@@ -416,7 +412,7 @@ abstract class DStream[T: ClassManifest] (
// =======================================================================
/** Return a new DStream by applying a function to all elements of this DStream. */
- def map[U: ClassManifest](mapFunc: T => U): DStream[U] = {
+ def map[U: ClassTag](mapFunc: T => U): DStream[U] = {
new MappedDStream(this, context.sparkContext.clean(mapFunc))
}
@@ -424,7 +420,7 @@ abstract class DStream[T: ClassManifest] (
* Return a new DStream by applying a function to all elements of this DStream,
* and then flattening the results
*/
- def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = {
+ def flatMap[U: ClassTag](flatMapFunc: T => Traversable[U]): DStream[U] = {
new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc))
}
@@ -438,12 +434,19 @@ abstract class DStream[T: ClassManifest] (
*/
def glom(): DStream[Array[T]] = new GlommedDStream(this)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
/**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.
*/
- def mapPartitions[U: ClassManifest](
+ def mapPartitions[U: ClassTag](
mapPartFunc: Iterator[T] => Iterator[U],
preservePartitioning: Boolean = false
): DStream[U] = {
@@ -479,7 +482,7 @@ abstract class DStream[T: ClassManifest] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: RDD[T] => Unit) {
this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
@@ -487,28 +490,60 @@ abstract class DStream[T: ClassManifest] (
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: (RDD[T], Time) => Unit) {
- val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc))
- ssc.registerOutputStream(newStream)
- newStream
+ ssc.registerOutputStream(new ForEachDStream(this, context.sparkContext.clean(foreachFunc)))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream.
+ */
+ def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
+ transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
}
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
- def transform[U: ClassManifest](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
- transform((r: RDD[T], t: Time) => transformFunc(r))
+ def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
+ //new TransformedDStream(this, context.sparkContext.clean(transformFunc))
+ val cleanedF = context.sparkContext.clean(transformFunc)
+ val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ assert(rdds.length == 1)
+ cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
+ }
+ new TransformedDStream[U](Seq(this), realTransformFunc)
}
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream and 'other' DStream.
*/
- def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
- new TransformedDStream(this, context.sparkContext.clean(transformFunc))
+ def transformWith[U: ClassTag, V: ClassTag](
+ other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
+ ): DStream[V] = {
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
+ transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U: ClassTag, V: ClassTag](
+ other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
+ ): DStream[V] = {
+ val cleanedF = ssc.sparkContext.clean(transformFunc)
+ val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ assert(rdds.length == 2)
+ val rdd1 = rdds(0).asInstanceOf[RDD[T]]
+ val rdd2 = rdds(1).asInstanceOf[RDD[U]]
+ cleanedF(rdd1, rdd2, time)
+ }
+ new TransformedDStream[V](Seq(this, other), realTransformFunc)
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
index 58a0da2870..3fd5d52403 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
@@ -20,13 +20,16 @@ package org.apache.spark.streaming
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.conf.Configuration
+
import collection.mutable.HashMap
import org.apache.spark.Logging
+import scala.collection.mutable.HashMap
+import scala.reflect.ClassTag
private[streaming]
-class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T])
+class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
extends Serializable with Logging {
protected val data = new HashMap[Time, AnyRef]()
@@ -107,4 +110,3 @@ class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T])
"[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]"
}
}
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index b9a58fded6..daed7ff7c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -21,6 +21,7 @@ import dstream.InputDStream
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
+import org.apache.spark.streaming.scheduler.Job
final private[streaming] class DStreamGraph extends Serializable with Logging {
initLogging()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
deleted file mode 100644
index 5233129506..0000000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/JobManager.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming
-
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
-import java.util.concurrent.Executors
-import collection.mutable.HashMap
-import collection.mutable.ArrayBuffer
-
-
-private[streaming]
-class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {
-
- class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
- def run() {
- SparkEnv.set(ssc.env)
- try {
- val timeTaken = job.run()
- logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format(
- (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0))
- } catch {
- case e: Exception =>
- logError("Running " + job + " failed", e)
- }
- clearJob(job)
- }
- }
-
- initLogging()
-
- val jobExecutor = Executors.newFixedThreadPool(numThreads)
- val jobs = new HashMap[Time, ArrayBuffer[Job]]
-
- def runJob(job: Job) {
- jobs.synchronized {
- jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job
- }
- jobExecutor.execute(new JobHandler(ssc, job))
- logInfo("Added " + job + " to queue")
- }
-
- def stop() {
- jobExecutor.shutdown()
- }
-
- private def clearJob(job: Job) {
- var timeCleared = false
- val time = job.time
- jobs.synchronized {
- val jobsOfTime = jobs.get(time)
- if (jobsOfTime.isDefined) {
- jobsOfTime.get -= job
- if (jobsOfTime.get.isEmpty) {
- jobs -= time
- timeCleared = true
- }
- } else {
- throw new Exception("Job finished for time " + job.time +
- " but time does not exist in jobs")
- }
- }
- if (timeCleared) {
- ssc.scheduler.clearOldMetadata(time)
- }
- }
-
- def getPendingTimes(): Array[Time] = {
- jobs.synchronized {
- jobs.keySet.toArray
- }
- }
-}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
index 757bc98981..80af96c060 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
@@ -18,16 +18,15 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.streaming.dstream.{ReducedWindowedDStream, StateDStream}
-import org.apache.spark.streaming.dstream.{CoGroupedDStream, ShuffledDStream}
-import org.apache.spark.streaming.dstream.{MapValuedDStream, FlatMapValuedDStream}
+import org.apache.spark.streaming.dstream._
import org.apache.spark.{Partitioner, HashPartitioner}
import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{Manifests, RDD, PairRDDFunctions}
+import org.apache.spark.rdd.{ClassTags, RDD, PairRDDFunctions}
import org.apache.spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -35,7 +34,7 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.conf.Configuration
-class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)])
+class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)])
extends Serializable {
private[streaming] def ssc = self.ssc
@@ -105,7 +104,7 @@ extends Serializable {
* combineByKey for RDDs. Please refer to combineByKey in
* [[org.apache.spark.rdd.PairRDDFunctions]] for more information.
*/
- def combineByKey[C: ClassManifest](
+ def combineByKey[C: ClassTag](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiner: (C, C) => C,
@@ -205,7 +204,7 @@ extends Serializable {
* DStream's batching interval
*/
def reduceByKeyAndWindow(
- reduceFunc: (V, V) => V,
+ reduceFunc: (V, V) => V,
windowDuration: Duration,
slideDuration: Duration
): DStream[(K, V)] = {
@@ -336,7 +335,7 @@ extends Serializable {
* corresponding state key-value pair will be eliminated.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S: ClassTag](
updateFunc: (Seq[V], Option[S]) => Option[S]
): DStream[(K, S)] = {
updateStateByKey(updateFunc, defaultPartitioner())
@@ -351,7 +350,7 @@ extends Serializable {
* @param numPartitions Number of partitions of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S: ClassTag](
updateFunc: (Seq[V], Option[S]) => Option[S],
numPartitions: Int
): DStream[(K, S)] = {
@@ -359,7 +358,7 @@ extends Serializable {
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then
@@ -367,7 +366,7 @@ extends Serializable {
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S: ClassTag](
updateFunc: (Seq[V], Option[S]) => Option[S],
partitioner: Partitioner
): DStream[(K, S)] = {
@@ -390,7 +389,7 @@ extends Serializable {
* @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S: ClassTag](
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
rememberPartitioner: Boolean
@@ -398,89 +397,171 @@ extends Serializable {
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
}
-
- def mapValues[U: ClassManifest](mapValuesFunc: V => U): DStream[(K, U)] = {
+ /**
+ * Return a new DStream by applying a map function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
+ def mapValues[U: ClassTag](mapValuesFunc: V => U): DStream[(K, U)] = {
new MapValuedDStream[K, V, U](self, mapValuesFunc)
}
- def flatMapValues[U: ClassManifest](
+ /**
+ * Return a new DStream by applying a flatmap function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
+ def flatMapValues[U: ClassTag](
flatMapValuesFunc: V => TraversableOnce[U]
): DStream[(K, U)] = {
new FlatMapValuedDStream[K, V, U](self, flatMapValuesFunc)
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number
* of partitions.
*/
- def cogroup[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
+ def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = {
cogroup(other, defaultPartitioner())
}
/**
- * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. Partitioner is used to partition each generated RDD.
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
*/
- def cogroup[W: ClassManifest](
+ def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (Seq[V], Seq[W]))] = {
+ cogroup(other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to partition the generated RDDs.
+ */
+ def cogroup[W: ClassTag](
other: DStream[(K, W)],
partitioner: Partitioner
): DStream[(K, (Seq[V], Seq[W]))] = {
-
- val cgd = new CoGroupedDStream[K](
- Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]),
- partitioner
- )
- val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)(
- classManifest[K],
- Manifests.seqSeqManifest
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner)
)
- pdfs.mapValues {
- case Seq(vs, ws) =>
- (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
- }
}
/**
- * Join `this` DStream with `other` DStream. HashPartitioner is used
- * to partition each generated RDD into default number of partitions.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
*/
- def join[W: ClassManifest](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
+ def join[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, W))] = {
join[W](other, defaultPartitioner())
}
/**
- * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
- * be generated by joining RDDs from `this` and other DStream. Uses the given
- * Partitioner to partition each generated RDD.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def join[W: ClassTag](other: DStream[(K, W)], numPartitions: Int): DStream[(K, (V, W))] = {
+ join[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
*/
- def join[W: ClassManifest](
+ def join[W: ClassTag](
other: DStream[(K, W)],
partitioner: Partitioner
): DStream[(K, (V, W))] = {
- this.cogroup(other, partitioner)
- .flatMapValues{
- case (vs, ws) =>
- for (v <- vs.iterator; w <- ws.iterator) yield (v, w)
- }
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.join(rdd2, partitioner)
+ )
}
/**
- * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
- * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def leftOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (V, Option[W]))] = {
+ leftOuterJoin[W](other, defaultPartitioner())
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def leftOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ numPartitions: Int
+ ): DStream[(K, (V, Option[W]))] = {
+ leftOuterJoin[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def leftOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (V, Option[W]))] = {
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.leftOuterJoin(rdd2, partitioner)
+ )
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def rightOuterJoin[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Option[V], W))] = {
+ rightOuterJoin[W](other, defaultPartitioner())
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def rightOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ numPartitions: Int
+ ): DStream[(K, (Option[V], W))] = {
+ rightOuterJoin[W](other, defaultPartitioner(numPartitions))
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def rightOuterJoin[W: ClassTag](
+ other: DStream[(K, W)],
+ partitioner: Partitioner
+ ): DStream[(K, (Option[V], W))] = {
+ self.transformWith(
+ other,
+ (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.rightOuterJoin(rdd2, partitioner)
+ )
+ }
+
+ /**
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval
+ * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
*/
def saveAsHadoopFiles[F <: OutputFormat[K, V]](
prefix: String,
suffix: String
- )(implicit fm: ClassManifest[F]) {
- saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ )(implicit fm: ClassTag[F]) {
+ saveAsHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
- * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is generated
- * based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
+ * Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval
+ * is generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix"
*/
def saveAsHadoopFiles(
prefix: String,
@@ -504,8 +585,8 @@ extends Serializable {
def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](
prefix: String,
suffix: String
- )(implicit fm: ClassManifest[F]) {
- saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ )(implicit fm: ClassTag[F]) {
+ saveAsNewAPIHadoopFiles(prefix, suffix, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -527,9 +608,7 @@ extends Serializable {
self.foreach(saveFunc)
}
- private def getKeyClass() = implicitly[ClassManifest[K]].erasure
+ private def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
- private def getValueClass() = implicitly[ClassManifest[V]].erasure
+ private def getValueClass() = implicitly[ClassTag[V]].runtimeClass
}
-
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 098081d245..41da028a3c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -34,6 +34,7 @@ import org.apache.spark.streaming.receivers.ActorReceiver
import scala.collection.mutable.Queue
import scala.collection.Map
+import scala.reflect.ClassTag
import java.io.InputStream
import java.util.concurrent.atomic.AtomicInteger
@@ -46,7 +47,8 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
import twitter4j.auth.Authorization
-
+import org.apache.spark.streaming.scheduler._
+import akka.util.ByteString
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -146,9 +148,10 @@ class StreamingContext private (
}
}
- protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null
- protected[streaming] var receiverJobThread: Thread = null
- protected[streaming] var scheduler: Scheduler = null
+ protected[streaming] val checkpointDuration: Duration = {
+ if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
+ }
+ protected[streaming] val scheduler = new JobScheduler(this)
/**
* Return the associated Spark context
@@ -191,7 +194,7 @@ class StreamingContext private (
* Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html
* @param receiver Custom implementation of NetworkReceiver
*/
- def networkStream[T: ClassManifest](
+ def networkStream[T: ClassTag](
receiver: NetworkReceiver[T]): DStream[T] = {
val inputStream = new PluggableInputDStream[T](this,
receiver)
@@ -211,7 +214,7 @@ class StreamingContext private (
* to ensure the type safety, i.e parametrized type of data received and actorStream
* should be same.
*/
- def actorStream[T: ClassManifest](
+ def actorStream[T: ClassTag](
props: Props,
name: String,
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
@@ -231,14 +234,14 @@ class StreamingContext private (
* and sub sequence refer to its payload.
* @param storageLevel RDD storage level. Defaults to memory-only.
*/
- def zeroMQStream[T: ClassManifest](
+ def zeroMQStream[T: ClassTag](
publisherUrl:String,
subscribe: Subscribe,
- bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T],
+ bytesToObjects: Seq[ByteString] ⇒ Iterator[T],
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
): DStream[T] = {
- actorStream(Props(new ZeroMQReceiver(publisherUrl,subscribe,bytesToObjects)),
+ actorStream(Props(new ZeroMQReceiver(publisherUrl, subscribe, bytesToObjects)),
"ZeroMQReceiver", storageLevel, supervisorStrategy)
}
@@ -256,10 +259,14 @@ class StreamingContext private (
groupId: String,
topics: Map[String, Int],
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
- ): DStream[String] = {
+ ): DStream[(String, String)] = {
val kafkaParams = Map[String, String](
- "zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
- kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
+ "zookeeper.connect" -> zkQuorum, "group.id" -> groupId,
+ "zookeeper.connection.timeout.ms" -> "10000")
+ kafkaStream[String, String, kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](
+ kafkaParams,
+ topics,
+ storageLevel)
}
/**
@@ -270,12 +277,16 @@ class StreamingContext private (
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
*/
- def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
+ def kafkaStream[
+ K: ClassTag,
+ V: ClassTag,
+ U <: kafka.serializer.Decoder[_]: Manifest,
+ T <: kafka.serializer.Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ): DStream[T] = {
- val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
+ ): DStream[(K, V)] = {
+ val inputStream = new KafkaInputDStream[K, V, U, T](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
@@ -307,7 +318,7 @@ class StreamingContext private (
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects received (after converting bytes to objects)
*/
- def socketStream[T: ClassManifest](
+ def socketStream[T: ClassTag](
hostname: String,
port: Int,
converter: (InputStream) => Iterator[T],
@@ -329,7 +340,7 @@ class StreamingContext private (
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
): DStream[SparkFlumeEvent] = {
- val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel)
+ val inputStream = new FlumeInputDStream[SparkFlumeEvent](this, hostname, port, storageLevel)
registerInputStream(inputStream)
inputStream
}
@@ -344,7 +355,7 @@ class StreamingContext private (
* @param storageLevel Storage level to use for storing the received objects
* @tparam T Type of the objects in the received blocks
*/
- def rawSocketStream[T: ClassManifest](
+ def rawSocketStream[T: ClassTag](
hostname: String,
port: Int,
storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2
@@ -364,9 +375,9 @@ class StreamingContext private (
* @tparam F Input format for reading HDFS file
*/
def fileStream[
- K: ClassManifest,
- V: ClassManifest,
- F <: NewInputFormat[K, V]: ClassManifest
+ K: ClassTag,
+ V: ClassTag,
+ F <: NewInputFormat[K, V]: ClassTag
] (directory: String): DStream[(K, V)] = {
val inputStream = new FileInputDStream[K, V, F](this, directory)
registerInputStream(inputStream)
@@ -384,9 +395,9 @@ class StreamingContext private (
* @tparam F Input format for reading HDFS file
*/
def fileStream[
- K: ClassManifest,
- V: ClassManifest,
- F <: NewInputFormat[K, V]: ClassManifest
+ K: ClassTag,
+ V: ClassTag,
+ F <: NewInputFormat[K, V]: ClassTag
] (directory: String, filter: Path => Boolean, newFilesOnly: Boolean): DStream[(K, V)] = {
val inputStream = new FileInputDStream[K, V, F](this, directory, filter, newFilesOnly)
registerInputStream(inputStream)
@@ -428,7 +439,7 @@ class StreamingContext private (
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
* @tparam T Type of objects in the RDD
*/
- def queueStream[T: ClassManifest](
+ def queueStream[T: ClassTag](
queue: Queue[RDD[T]],
oneAtATime: Boolean = true
): DStream[T] = {
@@ -444,7 +455,7 @@ class StreamingContext private (
* Set as null if no RDD should be returned when empty
* @tparam T Type of objects in the RDD
*/
- def queueStream[T: ClassManifest](
+ def queueStream[T: ClassTag](
queue: Queue[RDD[T]],
oneAtATime: Boolean,
defaultRDD: RDD[T]
@@ -454,14 +465,40 @@ class StreamingContext private (
inputStream
}
+/**
+ * Create an input stream that receives messages pushed by a mqtt publisher.
+ * @param brokerUrl Url of remote mqtt publisher
+ * @param topic topic name to subscribe to
+ * @param storageLevel RDD storage level. Defaults to memory-only.
+ */
+
+ def mqttStream(
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = {
+ val inputStream = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel)
+ registerInputStream(inputStream)
+ inputStream
+ }
/**
- * Create a unified DStream from multiple DStreams of the same type and same interval
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
*/
- def union[T: ClassManifest](streams: Seq[DStream[T]]): DStream[T] = {
+ def union[T: ClassTag](streams: Seq[DStream[T]]): DStream[T] = {
new UnionDStream[T](streams.toArray)
}
/**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams.
+ */
+ def transform[T: ClassTag](
+ dstreams: Seq[DStream[_]],
+ transformFunc: (Seq[RDD[_]], Time) => RDD[T]
+ ): DStream[T] = {
+ new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc))
+ }
+
+ /**
* Register an input stream that will be started (InputDStream.start() called) to get the
* input data.
*/
@@ -476,6 +513,13 @@ class StreamingContext private (
graph.addOutputStream(outputStream)
}
+ /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
+ def addStreamingListener(streamingListener: StreamingListener) {
+ scheduler.listenerBus.addListener(streamingListener)
+ }
+
protected def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@@ -491,27 +535,22 @@ class StreamingContext private (
* Start the execution of the streams.
*/
def start() {
- if (checkpointDir != null && checkpointDuration == null && graph != null) {
- checkpointDuration = graph.batchDuration
- }
-
validate()
+ // Get the network input streams
val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
case _ => false
}).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray
+ // Start the network input tracker (must start before receivers)
if (networkInputStreams.length > 0) {
- // Start the network input tracker (must start before receivers)
networkInputTracker = new NetworkInputTracker(this, networkInputStreams)
networkInputTracker.start()
}
-
Thread.sleep(1000)
// Start the scheduler
- scheduler = new Scheduler(this)
scheduler.start()
}
@@ -522,7 +561,6 @@ class StreamingContext private (
try {
if (scheduler != null) scheduler.stop()
if (networkInputTracker != null) networkInputTracker.stop()
- if (receiverJobThread != null) receiverJobThread.interrupt()
sc.stop()
logInfo("StreamingContext stopped successfully")
} catch {
@@ -534,7 +572,7 @@ class StreamingContext private (
object StreamingContext {
- implicit def toPairDStreamFunctions[K: ClassManifest, V: ClassManifest](stream: DStream[(K,V)]) = {
+ implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = {
new PairDStreamFunctions[K, V](stream)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
index d1932b6b05..d29033df32 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala
@@ -23,9 +23,11 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
+
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous
- * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.RDD]]
+ * sequence of RDDs (of the same type) representing a continuous stream of data (see [[org.apache.spark.rdd.RDD]]
* for more details on RDDs). DStreams can either be created from live data (such as, data from
* HDFS, Kafka or Flume) or it can be generated by transformation existing DStreams using operations
* such as `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming program is running, each
@@ -41,7 +43,7 @@ import org.apache.spark.rdd.RDD
* - A time interval at which the DStream generates an RDD
* - A function that is used to generate an RDD after each time interval
*/
-class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T])
+class JavaDStream[T](val dstream: DStream[T])(implicit val classTag: ClassTag[T])
extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] {
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
@@ -94,9 +96,15 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
*/
def union(that: JavaDStream[T]): JavaDStream[T] =
dstream.union(that.dstream)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaDStream[T] = dstream.repartition(numPartitions)
}
object JavaDStream {
- implicit def fromDStream[T: ClassManifest](dstream: DStream[T]): JavaDStream[T] =
+ implicit def fromDStream[T: ClassTag](dstream: DStream[T]): JavaDStream[T] =
new JavaDStream[T](dstream)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 459695b7ca..64f38ce1c0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -21,17 +21,19 @@ import java.util.{List => JList}
import java.lang.{Long => JLong}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.spark.streaming._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, _}
import java.util
import org.apache.spark.rdd.RDD
import JavaDStream._
trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]]
extends Serializable {
- implicit val classManifest: ClassManifest[T]
+ implicit val classTag: ClassTag[T]
def dstream: DStream[T]
@@ -120,10 +122,12 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* this DStream. Applying glom() to an RDD coalesces all elements within each partition into
* an array.
*/
- def glom(): JavaDStream[JList[T]] =
+ def glom(): JavaDStream[JList[T]] = {
new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+ }
+
- /** Return the StreamingContext associated with this DStream */
+ /** Return the [[org.apache.spark.streaming.StreamingContext]] associated with this DStream */
def context(): StreamingContext = dstream.context()
/** Return a new DStream by applying a function to all elements of this DStream. */
@@ -133,7 +137,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/** Return a new DStream by applying a function to all elements of this DStream. */
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = {
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType())
}
@@ -154,7 +158,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -238,7 +242,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: JFunction[R, Void]) {
dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd)))
@@ -246,7 +250,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Apply a function to each RDD in this DStream. This is an output operator, so
- * this DStream will be registered as an output stream and therefore materialized.
+ * 'this' DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: JFunction2[R, Time, Void]) {
dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time))
@@ -254,11 +258,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
def scalaTransform (in: RDD[T]): RDD[U] =
transformFunc.call(wrapRDD(in)).rdd
dstream.transform(scalaTransform(_))
@@ -266,11 +270,11 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
def scalaTransform (in: RDD[T], time: Time): RDD[U] =
transformFunc.call(wrapRDD(in), time).rdd
dstream.transform(scalaTransform(_, _))
@@ -278,14 +282,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]):
JavaPairDStream[K2, V2] = {
- implicit val cmk: ClassManifest[K2] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]]
- implicit val cmv: ClassManifest[V2] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]]
+ implicit val cmk: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
def scalaTransform (in: RDD[T]): RDD[(K2, V2)] =
transformFunc.call(wrapRDD(in)).rdd
dstream.transform(scalaTransform(_))
@@ -293,20 +297,96 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
/**
* Return a new DStream in which each RDD is generated by applying a function
- * on each RDD of this DStream.
+ * on each RDD of 'this' DStream.
*/
def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]):
JavaPairDStream[K2, V2] = {
- implicit val cmk: ClassManifest[K2] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]]
- implicit val cmv: ClassManifest[V2] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]]
+ implicit val cmk: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] =
transformFunc.call(wrapRDD(in), time).rdd
dstream.transform(scalaTransform(_, _))
}
/**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U, W](
+ other: JavaDStream[U],
+ transformFunc: JFunction3[R, JavaRDD[U], Time, JavaRDD[W]]
+ ): JavaDStream[W] = {
+ implicit val cmu: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+ implicit val cmv: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[W] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[U, W](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[U, K2, V2](
+ other: JavaDStream[U],
+ transformFunc: JFunction3[R, JavaRDD[U], Time, JavaPairRDD[K2, V2]]
+ ): JavaPairDStream[K2, V2] = {
+ implicit val cmu: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[U], time: Time): RDD[(K2, V2)] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[U, (K2, V2)](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[K2, V2, W](
+ other: JavaPairDStream[K2, V2],
+ transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaRDD[W]]
+ ): JavaDStream[W] = {
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ implicit val cmw: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[W] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[(K2, V2), W](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
+ * Return a new DStream in which each RDD is generated by applying a function
+ * on each RDD of 'this' DStream and 'other' DStream.
+ */
+ def transformWith[K2, V2, K3, V3](
+ other: JavaPairDStream[K2, V2],
+ transformFunc: JFunction3[R, JavaPairRDD[K2, V2], Time, JavaPairRDD[K3, V3]]
+ ): JavaPairDStream[K3, V3] = {
+ implicit val cmk2: ClassTag[K2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K2]]
+ implicit val cmv2: ClassTag[V2] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V2]]
+ implicit val cmk3: ClassTag[K3] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K3]]
+ implicit val cmv3: ClassTag[V3] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V3]]
+ def scalaTransform (inThis: RDD[T], inThat: RDD[(K2, V2)], time: Time): RDD[(K3, V3)] =
+ transformFunc.call(wrapRDD(inThis), other.wrapRDD(inThat), time).rdd
+ dstream.transformWith[(K2, V2), (K3, V3)](other.dstream, scalaTransform(_, _, _))
+ }
+
+ /**
* Enable periodic checkpointing of RDDs of this DStream
* @param interval Time interval after which generated RDD will be checkpointed
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 978fca33ad..dfd6e27c3e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -21,10 +21,11 @@ import java.util.{List => JList}
import java.lang.{Long => JLong}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
-import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
+import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, Function3 => JFunction3}
import org.apache.spark.Partitioner
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -36,8 +37,8 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PairRDDFunctions
class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
- implicit val kManifiest: ClassManifest[K],
- implicit val vManifest: ClassManifest[V])
+ implicit val kManifest: ClassTag[K],
+ implicit val vManifest: ClassTag[V])
extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
@@ -59,6 +60,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/** Persist the RDDs of this DStream with the given storage level */
def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel)
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaPairDStream[K, V] = dstream.repartition(numPartitions)
+
/** Method that generates a RDD for the given Duration */
def compute(validTime: Time): JavaPairRDD[K, V] = {
dstream.compute(validTime) match {
@@ -148,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* Combine elements of each key in DStream's RDDs using custom function. This is similar to the
- * combineByKey for RDDs. Please refer to combineByKey in [[PairRDDFunctions]] for more
+ * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more
* information.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
@@ -156,8 +163,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner
): JavaPairDStream[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner)
}
@@ -413,7 +420,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -422,13 +429,13 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
*/
def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]])
: JavaPairDStream[K, S] = {
- implicit val cm: ClassManifest[S] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]]
+ implicit val cm: ClassTag[S] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]]
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc))
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
* @param updateFunc State update function. If `this` function returns None, then
@@ -436,15 +443,17 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* @param numPartitions Number of partitions of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S](
updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
numPartitions: Int)
: JavaPairDStream[K, S] = {
+ implicit val cm: ClassTag[S] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]]
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions)
}
/**
- * Create a new "state" DStream where the state for each key is updated by applying
+ * Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of the key.
* [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
* @param updateFunc State update function. If `this` function returns None, then
@@ -452,75 +461,194 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream.
* @tparam S State type
*/
- def updateStateByKey[S: ClassManifest](
+ def updateStateByKey[S](
updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
partitioner: Partitioner
): JavaPairDStream[K, S] = {
+ implicit val cm: ClassTag[S] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[S]]
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
}
+
+ /**
+ * Return a new DStream by applying a map function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
dstream.mapValues(f)
}
+ /**
+ * Return a new DStream by applying a flatmap function to the value of each key-value pairs in
+ * 'this' DStream without changing the key.
+ */
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairDStream[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
dstream.flatMapValues(fn)
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. HashPartitioner is used to partition each generated RDD into default number
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number
* of partitions.
*/
def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = {
- implicit val cm: ClassManifest[W] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
}
/**
- * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this`
- * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that
- * key in both RDDs. Partitioner is used to partition each generated RDD.
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
*/
- def cogroup[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
- : JavaPairDStream[K, (JList[V], JList[W])] = {
- implicit val cm: ClassManifest[W] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ def cogroup[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (JList[V], JList[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ dstream.cogroup(other.dstream, numPartitions)
+ .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
+ }
+
+ /**
+ * Return a new DStream by applying 'cogroup' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ */
+ def cogroup[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (JList[V], JList[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.cogroup(other.dstream, partitioner)
- .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
+ .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2))))
}
/**
- * Join `this` DStream with `other` DStream. HashPartitioner is used
- * to partition each generated RDD into default number of partitions.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
*/
def join[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, W)] = {
- implicit val cm: ClassManifest[W] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.join(other.dstream)
}
/**
- * Join `this` DStream with `other` DStream, that is, each RDD of the new DStream will
- * be generated by joining RDDs from `this` and other DStream. Uses the given
- * Partitioner to partition each generated RDD.
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
*/
- def join[W](other: JavaPairDStream[K, W], partitioner: Partitioner)
- : JavaPairDStream[K, (V, W)] = {
- implicit val cm: ClassManifest[W] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[W]]
+ def join[W](other: JavaPairDStream[K, W], numPartitions: Int): JavaPairDStream[K, (V, W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ dstream.join(other.dstream, numPartitions)
+ }
+
+ /**
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+ */
+ def join[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (V, W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
dstream.join(other.dstream, partitioner)
}
/**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def leftOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'left outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def leftOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream, numPartitions)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'join' between RDDs of `this` DStream and `other` DStream.
+ * The supplied [[org.apache.spark.Partitioner]] is used to control the partitioning of each RDD.
+ */
+ def leftOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (V, Optional[W])] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.leftOuterJoin(other.dstream, partitioner)
+ joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with Spark's default
+ * number of partitions.
+ */
+ def rightOuterJoin[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. Hash partitioning is used to generate the RDDs with `numPartitions`
+ * partitions.
+ */
+ def rightOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ numPartitions: Int
+ ): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream, numPartitions)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
+ * Return a new DStream by applying 'right outer join' between RDDs of `this` DStream and
+ * `other` DStream. The supplied [[org.apache.spark.Partitioner]] is used to control
+ * the partitioning of each RDD.
+ */
+ def rightOuterJoin[W](
+ other: JavaPairDStream[K, W],
+ partitioner: Partitioner
+ ): JavaPairDStream[K, (Optional[V], W)] = {
+ implicit val cm: ClassTag[W] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+ val joinResult = dstream.rightOuterJoin(other.dstream, partitioner)
+ joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}
+ }
+
+ /**
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
@@ -590,24 +718,29 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
}
- override val classManifest: ClassManifest[(K, V)] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ /** Convert to a JavaDStream */
+ def toJavaDStream(): JavaDStream[(K, V)] = {
+ new JavaDStream[(K, V)](dstream)
+ }
+
+ override val classTag: ClassTag[(K, V)] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
}
object JavaPairDStream {
- implicit def fromPairDStream[K: ClassManifest, V: ClassManifest](dstream: DStream[(K, V)])
- :JavaPairDStream[K, V] =
+ implicit def fromPairDStream[K: ClassTag, V: ClassTag](dstream: DStream[(K, V)]) = {
new JavaPairDStream[K, V](dstream)
+ }
def fromJavaDStream[K, V](dstream: JavaDStream[(K, V)]): JavaPairDStream[K, V] = {
- implicit val cmk: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val cmv: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
new JavaPairDStream[K, V](dstream.dstream)
}
- def scalaToJavaLong[K: ClassManifest](dstream: JavaPairDStream[K, Long])
+ def scalaToJavaLong[K: ClassTag](dstream: JavaPairDStream[K, Long])
: JavaPairDStream[K, JLong] = {
StreamingContext.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 54ba3e6025..78d318cf27 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -17,26 +17,29 @@
package org.apache.spark.streaming.api.java
-import java.lang.{Long => JLong, Integer => JInt}
+import java.lang.{Integer => JInt}
import java.io.InputStream
-import java.util.{Map => JMap}
+import java.util.{Map => JMap, List => JList}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import twitter4j.Status
import akka.actor.Props
import akka.actor.SupervisorStrategy
import akka.zeromq.Subscribe
+import akka.util.ByteString
+
import twitter4j.auth.Authorization
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
-import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.dstream._
-import org.apache.spark.streaming.receivers.{ActorReceiver, ReceiverSupervisorStrategy}
+import org.apache.spark.streaming.scheduler.StreamingListener
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
@@ -141,11 +144,12 @@ class JavaStreamingContext(val ssc: StreamingContext) {
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
- : JavaDStream[String] = {
- implicit val cmt: ClassManifest[String] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+ : JavaPairDStream[String, String] = {
+ implicit val cmt: ClassTag[String] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
StorageLevel.MEMORY_ONLY_SER_2)
+
}
/**
@@ -162,34 +166,43 @@ class JavaStreamingContext(val ssc: StreamingContext) {
groupId: String,
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[String] = {
- implicit val cmt: ClassManifest[String] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
+ : JavaPairDStream[String, String] = {
+ implicit val cmt: ClassTag[String] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)
}
/**
* Create an input stream that pulls messages form a Kafka Broker.
- * @param typeClass Type of RDD
- * @param decoderClass Type of kafka decoder
+ * @param keyTypeClass Key type of RDD
+ * @param valueTypeClass value type of RDD
+ * @param keyDecoderClass Type of kafka key decoder
+ * @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration paramaters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
- def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
- typeClass: Class[T],
- decoderClass: Class[D],
+ def kafkaStream[K, V, U <: kafka.serializer.Decoder[_], T <: kafka.serializer.Decoder[_]](
+ keyTypeClass: Class[K],
+ valueTypeClass: Class[V],
+ keyDecoderClass: Class[U],
+ valueDecoderClass: Class[T],
kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
storageLevel: StorageLevel)
- : JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
- ssc.kafkaStream[T, D](
+ : JavaPairDStream[K, V] = {
+ implicit val keyCmt: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val valueCmt: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+
+ implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]]
+ implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]]
+
+ ssc.kafkaStream[K, V, U, T](
kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)
@@ -237,8 +250,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
storageLevel: StorageLevel)
: JavaDStream[T] = {
def fn = (x: InputStream) => converter.apply(x).toIterator
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cmt: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
ssc.socketStream(hostname, port, fn, storageLevel)
}
@@ -266,8 +279,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
hostname: String,
port: Int,
storageLevel: StorageLevel): JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cmt: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port, storageLevel))
}
@@ -281,8 +294,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @tparam T Type of the objects in the received blocks
*/
def rawSocketStream[T](hostname: String, port: Int): JavaDStream[T] = {
- implicit val cmt: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cmt: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port))
}
@@ -296,13 +309,13 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @tparam F Input format for reading HDFS file
*/
def fileStream[K, V, F <: NewInputFormat[K, V]](directory: String): JavaPairDStream[K, V] = {
- implicit val cmk: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val cmv: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
- implicit val cmf: ClassManifest[F] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]]
- ssc.fileStream[K, V, F](directory);
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ implicit val cmf: ClassTag[F] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[F]]
+ ssc.fileStream[K, V, F](directory)
}
/**
@@ -396,7 +409,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
def twitterStream(): JavaDStream[Status] = {
ssc.twitterStream()
}
-
+
/**
* Create an input stream with any arbitrary user implemented actor receiver.
* @param props Props object defining creation of the actor
@@ -414,8 +427,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
storageLevel: StorageLevel,
supervisorStrategy: SupervisorStrategy
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
ssc.actorStream[T](props, name, storageLevel, supervisorStrategy)
}
@@ -435,8 +448,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
name: String,
storageLevel: StorageLevel
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
ssc.actorStream[T](props, name, storageLevel)
}
@@ -454,8 +467,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
props: Props,
name: String
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
ssc.actorStream[T](props, name)
}
@@ -472,12 +485,12 @@ class JavaStreamingContext(val ssc: StreamingContext) {
def zeroMQStream[T](
publisherUrl:String,
subscribe: Subscribe,
- bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T],
+ bytesToObjects: Seq[ByteString] ⇒ Iterator[T],
storageLevel: StorageLevel,
supervisorStrategy: SupervisorStrategy
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
ssc.zeroMQStream[T](publisherUrl, subscribe, bytesToObjects, storageLevel, supervisorStrategy)
}
@@ -497,9 +510,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]],
storageLevel: StorageLevel
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
+ def fn(x: Seq[ByteString]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
ssc.zeroMQStream[T](publisherUrl, subscribe, fn, storageLevel)
}
@@ -517,9 +530,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
subscribe: Subscribe,
bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]]
): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
+ def fn(x: Seq[ByteString]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator
ssc.zeroMQStream[T](publisherUrl, subscribe, fn)
}
@@ -539,8 +552,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @tparam T Type of objects in the RDD
*/
def queueStream[T](queue: java.util.Queue[JavaRDD[T]]): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
val sQueue = new scala.collection.mutable.Queue[RDD[T]]
sQueue.enqueue(queue.map(_.rdd).toSeq: _*)
ssc.queueStream(sQueue)
@@ -556,8 +569,8 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @tparam T Type of objects in the RDD
*/
def queueStream[T](queue: java.util.Queue[JavaRDD[T]], oneAtATime: Boolean): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
val sQueue = new scala.collection.mutable.Queue[RDD[T]]
sQueue.enqueue(queue.map(_.rdd).toSeq: _*)
ssc.queueStream(sQueue, oneAtATime)
@@ -577,14 +590,85 @@ class JavaStreamingContext(val ssc: StreamingContext) {
queue: java.util.Queue[JavaRDD[T]],
oneAtATime: Boolean,
defaultRDD: JavaRDD[T]): JavaDStream[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
val sQueue = new scala.collection.mutable.Queue[RDD[T]]
sQueue.enqueue(queue.map(_.rdd).toSeq: _*)
ssc.queueStream(sQueue, oneAtATime, defaultRDD.rdd)
}
/**
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
+ */
+ def union[T](first: JavaDStream[T], rest: JList[JavaDStream[T]]): JavaDStream[T] = {
+ val dstreams: Seq[DStream[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream)
+ implicit val cm: ClassTag[T] = first.classTag
+ ssc.union(dstreams)(cm)
+ }
+
+ /**
+ * Create a unified DStream from multiple DStreams of the same type and same slide duration.
+ */
+ def union[K, V](
+ first: JavaPairDStream[K, V],
+ rest: JList[JavaPairDStream[K, V]]
+ ): JavaPairDStream[K, V] = {
+ val dstreams: Seq[DStream[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.dstream)
+ implicit val cm: ClassTag[(K, V)] = first.classTag
+ implicit val kcm: ClassTag[K] = first.kManifest
+ implicit val vcm: ClassTag[V] = first.vManifest
+ new JavaPairDStream[K, V](ssc.union(dstreams)(cm))(kcm, vcm)
+ }
+
+ /**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+ * same as the order of corresponding DStreams in the list. Note that for adding a
+ * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream().
+ * In the transform function, convert the JavaRDD corresponding to that JavaDStream to
+ * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD().
+ */
+ def transform[T](
+ dstreams: JList[JavaDStream[_]],
+ transformFunc: JFunction2[JList[JavaRDD[_]], Time, JavaRDD[T]]
+ ): JavaDStream[T] = {
+ implicit val cmt: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
+ val scalaDStreams = dstreams.map(_.dstream).toSeq
+ val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList
+ transformFunc.call(jrdds, time).rdd
+ }
+ ssc.transform(scalaDStreams, scalaTransformFunc)
+ }
+
+ /**
+ * Create a new DStream in which each RDD is generated by applying a function on RDDs of
+ * the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+ * same as the order of corresponding DStreams in the list. Note that for adding a
+ * JavaPairDStream in the list of JavaDStreams, convert it to a JavaDStream using
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream]].toJavaDStream().
+ * In the transform function, convert the JavaRDD corresponding to that JavaDStream to
+ * a JavaPairRDD using [[org.apache.spark.api.java.JavaPairRDD]].fromJavaRDD().
+ */
+ def transform[K, V](
+ dstreams: JList[JavaDStream[_]],
+ transformFunc: JFunction2[JList[JavaRDD[_]], Time, JavaPairRDD[K, V]]
+ ): JavaPairDStream[K, V] = {
+ implicit val cmk: ClassTag[K] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ val scalaDStreams = dstreams.map(_.dstream).toSeq
+ val scalaTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
+ val jrdds = rdds.map(rdd => JavaRDD.fromRDD[AnyRef](rdd.asInstanceOf[RDD[AnyRef]])).toList
+ transformFunc.call(jrdds, time).rdd
+ }
+ ssc.transform(scalaDStreams, scalaTransformFunc)
+ }
+
+ /**
* Sets the context to periodically checkpoint the DStream operations for master
* fault-tolerance. The graph will be checkpointed every batch interval.
* @param directory HDFS-compatible directory where the checkpoint data will be reliably stored
@@ -604,6 +688,13 @@ class JavaStreamingContext(val ssc: StreamingContext) {
ssc.remember(duration)
}
+ /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
+ * receiving system events related to streaming.
+ */
+ def addStreamingListener(streamingListener: StreamingListener) {
+ ssc.addStreamingListener(streamingListener)
+ }
+
/**
* Starts the execution of the streams.
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala
deleted file mode 100644
index 4eddc755b9..0000000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/CoGroupedDStream.scala
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.dstream
-
-import org.apache.spark.Partitioner
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.CoGroupedRDD
-import org.apache.spark.streaming.{Time, DStream, Duration}
-
-private[streaming]
-class CoGroupedDStream[K : ClassManifest](
- parents: Seq[DStream[(K, _)]],
- partitioner: Partitioner
- ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) {
-
- if (parents.length == 0) {
- throw new IllegalArgumentException("Empty array of parents")
- }
-
- if (parents.map(_.ssc).distinct.size > 1) {
- throw new IllegalArgumentException("Array of parents have different StreamingContexts")
- }
-
- if (parents.map(_.slideDuration).distinct.size > 1) {
- throw new IllegalArgumentException("Array of parents have different slide times")
- }
-
- override def dependencies = parents.toList
-
- override def slideDuration: Duration = parents.head.slideDuration
-
- override def compute(validTime: Time): Option[RDD[(K, Seq[Seq[_]])]] = {
- val part = partitioner
- val rdds = parents.flatMap(_.getOrCompute(validTime))
- if (rdds.size > 0) {
- val q = new CoGroupedRDD[K](rdds, part)
- Some(q)
- } else {
- None
- }
- }
-
-}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
index a9a05c9981..f396c34758 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala
@@ -19,11 +19,12 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Time, StreamingContext}
+import scala.reflect.ClassTag
/**
* An input stream that always returns the same RDD on each timestep. Useful for testing.
*/
-class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
+class ConstantInputDStream[T: ClassTag](ssc_ : StreamingContext, rdd: RDD[T])
extends InputDStream[T](ssc_) {
override def start() {}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index fea0573b77..39e25239bf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -26,14 +26,16 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import scala.collection.mutable.{HashSet, HashMap}
+import scala.reflect.ClassTag
+
import java.io.{ObjectInputStream, IOException}
private[streaming]
-class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest](
+class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : ClassTag](
@transient ssc_ : StreamingContext,
directory: String,
filter: Path => Boolean = FileInputDStream.defaultFilter,
- newFilesOnly: Boolean = true)
+ newFilesOnly: Boolean = true)
extends InputDStream[(K, V)](ssc_) {
protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
@@ -54,7 +56,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
}
logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly)
}
-
+
override def stop() { }
/**
@@ -100,7 +102,7 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K
latestModTimeFiles += path.toString
logDebug("Accepted " + path)
return true
- }
+ }
}
}
logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime)
@@ -195,5 +197,3 @@ private[streaming]
object FileInputDStream {
def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".")
}
-
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
index 91ee2c1a36..db2e0a4cee 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala
@@ -19,9 +19,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
private[streaming]
-class FilteredDStream[T: ClassManifest](
+class FilteredDStream[T: ClassTag](
parent: DStream[T],
filterFunc: T => Boolean
) extends DStream[T](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
index ca7d7ca49e..244dc3ee4f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala
@@ -20,9 +20,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
+import scala.reflect.ClassTag
private[streaming]
-class FlatMapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
parent: DStream[(K, V)],
flatMapValueFunc: V => TraversableOnce[U]
) extends DStream[(K, U)](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
index b37966f9a7..336c4b7a92 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala
@@ -19,9 +19,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
private[streaming]
-class FlatMappedDStream[T: ClassManifest, U: ClassManifest](
+class FlatMappedDStream[T: ClassTag, U: ClassTag](
parent: DStream[T],
flatMapFunc: T => Traversable[U]
) extends DStream[U](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala
index 18de772946..60d79175f1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala
@@ -22,6 +22,7 @@ import java.io.{ObjectInput, ObjectOutput, Externalizable}
import java.nio.ByteBuffer
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.flume.source.avro.AvroSourceProtocol
import org.apache.flume.source.avro.AvroFlumeEvent
@@ -34,7 +35,7 @@ import org.apache.spark.util.Utils
import org.apache.spark.storage.StorageLevel
private[streaming]
-class FlumeInputDStream[T: ClassManifest](
+class FlumeInputDStream[T: ClassTag](
@transient ssc_ : StreamingContext,
host: String,
port: Int,
@@ -137,8 +138,8 @@ class FlumeReceiver(
protected override def onStart() {
val responder = new SpecificResponder(
- classOf[AvroSourceProtocol], new FlumeEventServer(this));
- val server = new NettyServer(responder, new InetSocketAddress(host, port));
+ classOf[AvroSourceProtocol], new FlumeEventServer(this))
+ val server = new NettyServer(responder, new InetSocketAddress(host, port))
blockGenerator.start()
server.start()
logInfo("Flume receiver started")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
index e21bac4602..364abcde68 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala
@@ -18,10 +18,12 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, DStream, Job, Time}
+import org.apache.spark.streaming.{Duration, DStream, Time}
+import org.apache.spark.streaming.scheduler.Job
+import scala.reflect.ClassTag
private[streaming]
-class ForEachDStream[T: ClassManifest] (
+class ForEachDStream[T: ClassTag] (
parent: DStream[T],
foreachFunc: (RDD[T], Time) => Unit
) extends DStream[Unit](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
index 4294b07d91..23136f44fa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala
@@ -19,9 +19,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
private[streaming]
-class GlommedDStream[T: ClassManifest](parent: DStream[T])
+class GlommedDStream[T: ClassTag](parent: DStream[T])
extends DStream[Array[T]](parent.ssc) {
override def dependencies = List(parent)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
index 674b27118c..f01e67fe13 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala
@@ -19,6 +19,8 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Time, Duration, StreamingContext, DStream}
+import scala.reflect.ClassTag
+
/**
* This is the abstract base class for all input streams. This class provides to methods
* start() and stop() which called by the scheduler to start and stop receiving data/
@@ -30,7 +32,7 @@ import org.apache.spark.streaming.{Time, Duration, StreamingContext, DStream}
* that requires running a receiver on the worker nodes, use NetworkInputDStream
* as the parent class.
*/
-abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext)
+abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext)
extends DStream[T](ssc_) {
var lastValidTime: Time = null
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
index 51e913675d..526f5564c7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/KafkaInputDStream.scala
@@ -19,52 +19,55 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
+import org.apache.spark.streaming.StreamingContext
import java.util.Properties
import java.util.concurrent.Executors
import kafka.consumer._
-import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.Decoder
-import kafka.utils.{Utils, ZKGroupTopicDirs}
-import kafka.utils.ZkUtils._
+import kafka.utils.VerifiableProperties
import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient._
import scala.collection.Map
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
-
+import scala.reflect.ClassTag
/**
* Input stream that pulls messages from a Kafka Broker.
- *
+ *
* @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level.
*/
private[streaming]
-class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
+class KafkaInputDStream[
+ K: ClassTag,
+ V: ClassTag,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
- ) extends NetworkInputDStream[T](ssc_ ) with Logging {
+ ) extends NetworkInputDStream[(K, V)](ssc_) with Logging {
-
- def getReceiver(): NetworkReceiver[T] = {
- new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
- .asInstanceOf[NetworkReceiver[T]]
+ def getReceiver(): NetworkReceiver[(K, V)] = {
+ new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
+ .asInstanceOf[NetworkReceiver[(K, V)]]
}
}
private[streaming]
-class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
- kafkaParams: Map[String, String],
- topics: Map[String, Int],
- storageLevel: StorageLevel
+class KafkaReceiver[
+ K: ClassTag,
+ V: ClassTag,
+ U <: Decoder[_]: Manifest,
+ T <: Decoder[_]: Manifest](
+ kafkaParams: Map[String, String],
+ topics: Map[String, Int],
+ storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
@@ -83,27 +86,35 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
- logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
+ logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id"))
// Kafka connection properties
val props = new Properties()
kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
- logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
+ logInfo("Connecting to Zookeper: " + kafkaParams("zookeeper.connect"))
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig)
- logInfo("Connected to " + kafkaParams("zk.connect"))
+ logInfo("Connected to " + kafkaParams("zookeeper.connect"))
// When autooffset.reset is defined, it is our responsibility to try and whack the
// consumer group zk node.
- if (kafkaParams.contains("autooffset.reset")) {
- tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
+ if (kafkaParams.contains("auto.offset.reset")) {
+ tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id"))
}
+ val keyDecoder = manifest[U].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[K]]
+ val valueDecoder = manifest[T].runtimeClass.getConstructor(classOf[VerifiableProperties])
+ .newInstance(consumerConfig.props)
+ .asInstanceOf[Decoder[V]]
+
// Create Threads for each Topic/Message Stream we are listening
- val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
- val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
+ val topicMessageStreams = consumerConnector.createMessageStreams(
+ topics, keyDecoder, valueDecoder)
+
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
@@ -112,11 +123,12 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
}
// Handles Kafka Messages
- private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
+ private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V])
+ extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
for (msgAndMetadata <- stream) {
- blockGenerator += msgAndMetadata.message
+ blockGenerator += (msgAndMetadata.key, msgAndMetadata.message)
}
}
}
@@ -135,7 +147,7 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
zk.deleteRecursive(dir)
zk.close()
} catch {
- case _ => // swallow
+ case _ : Throwable => // swallow
}
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala
new file mode 100644
index 0000000000..ef4a737568
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import org.apache.spark.Logging
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext }
+
+import java.util.Properties
+import java.util.concurrent.Executors
+import java.io.IOException
+
+import org.eclipse.paho.client.mqttv3.MqttCallback
+import org.eclipse.paho.client.mqttv3.MqttClient
+import org.eclipse.paho.client.mqttv3.MqttClientPersistence
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
+import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
+import org.eclipse.paho.client.mqttv3.MqttException
+import org.eclipse.paho.client.mqttv3.MqttMessage
+import org.eclipse.paho.client.mqttv3.MqttTopic
+
+import scala.collection.Map
+import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
+
+/**
+ * Input stream that subscribe messages from a Mqtt Broker.
+ * Uses eclipse paho as MqttClient http://www.eclipse.org/paho/
+ * @param brokerUrl Url of remote mqtt publisher
+ * @param topic topic name to subscribe to
+ * @param storageLevel RDD storage level.
+ */
+
+private[streaming]
+class MQTTInputDStream[T: ClassTag](
+ @transient ssc_ : StreamingContext,
+ brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ) extends NetworkInputDStream[T](ssc_) with Logging {
+
+ def getReceiver(): NetworkReceiver[T] = {
+ new MQTTReceiver(brokerUrl, topic, storageLevel)
+ .asInstanceOf[NetworkReceiver[T]]
+ }
+}
+
+private[streaming]
+class MQTTReceiver(brokerUrl: String,
+ topic: String,
+ storageLevel: StorageLevel
+ ) extends NetworkReceiver[Any] {
+ lazy protected val blockGenerator = new BlockGenerator(storageLevel)
+
+ def onStop() {
+ blockGenerator.stop()
+ }
+
+ def onStart() {
+
+ blockGenerator.start()
+
+ // Set up persistence for messages
+ var peristance: MqttClientPersistence = new MemoryPersistence()
+
+ // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance
+ var client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", peristance)
+
+ // Connect to MqttBroker
+ client.connect()
+
+ // Subscribe to Mqtt topic
+ client.subscribe(topic)
+
+ // Callback automatically triggers as and when new message arrives on specified topic
+ var callback: MqttCallback = new MqttCallback() {
+
+ // Handles Mqtt message
+ override def messageArrived(arg0: String, arg1: MqttMessage) {
+ blockGenerator += new String(arg1.getPayload())
+ }
+
+ override def deliveryComplete(arg0: IMqttDeliveryToken) {
+ }
+
+ override def connectionLost(arg0: Throwable) {
+ logInfo("Connection lost " + arg0)
+ }
+ }
+
+ // Set up callback for MqttClient
+ client.setCallback(callback)
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
index 5329601a6f..8a04060e5b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala
@@ -19,9 +19,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
private[streaming]
-class MapPartitionedDStream[T: ClassManifest, U: ClassManifest](
+class MapPartitionedDStream[T: ClassTag, U: ClassTag](
parent: DStream[T],
mapPartFunc: Iterator[T] => Iterator[U],
preservePartitioning: Boolean
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
index 8290df90a2..0ce364fd46 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala
@@ -20,9 +20,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
+import scala.reflect.ClassTag
private[streaming]
-class MapValuedDStream[K: ClassManifest, V: ClassManifest, U: ClassManifest](
+class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag](
parent: DStream[(K, V)],
mapValueFunc: V => U
) extends DStream[(K, U)](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
index b1682afea3..c0b7491d09 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala
@@ -19,9 +19,10 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
private[streaming]
-class MappedDStream[T: ClassManifest, U: ClassManifest] (
+class MappedDStream[T: ClassTag, U: ClassTag] (
parent: DStream[T],
mapFunc: T => U
) extends DStream[U](parent.ssc) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 8d3ac0fc65..5add20871e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -21,17 +21,19 @@ import java.util.concurrent.ArrayBlockingQueue
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
import akka.actor.{Props, Actor}
import akka.pattern.ask
-import akka.dispatch.Await
-import akka.util.duration._
import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
+import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
@@ -42,7 +44,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
* @param ssc_ Streaming context that will execute this input stream
* @tparam T Class type of the object of this stream
*/
-abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : StreamingContext)
+abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
extends InputDStream[T](ssc_) {
// This is an unique identifier that is used to match the network receiver with the
@@ -84,7 +86,7 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe
* Abstract class of a receiver that can be run on worker nodes to receive external data. See
* [[org.apache.spark.streaming.dstream.NetworkInputDStream]] for an explanation.
*/
-abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Logging {
+abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging {
initLogging()
@@ -176,8 +178,8 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Attempting to register with tracker")
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
- val url = "akka://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
- val tracker = env.actorSystem.actorFor(url)
+ val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+ val tracker = env.actorSystem.actorSelection(url)
val timeout = 5.seconds
override def preStart() {
@@ -232,11 +234,11 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
logInfo("Data handler stopped")
}
- def += (obj: T) {
+ def += (obj: T): Unit = synchronized {
currentBuffer += obj
}
- private def updateCurrentBuffer(time: Long) {
+ private def updateCurrentBuffer(time: Long): Unit = synchronized {
try {
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala
index 15782f5c11..6f9477020a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala
@@ -18,9 +18,10 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.streaming.StreamingContext
+import scala.reflect.ClassTag
private[streaming]
-class PluggableInputDStream[T: ClassManifest](
+class PluggableInputDStream[T: ClassTag](
@transient ssc_ : StreamingContext,
receiver: NetworkReceiver[T]) extends NetworkInputDStream[T](ssc_) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
index 7d9f3521b1..97325f8ea3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala
@@ -19,13 +19,13 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
-
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.streaming.{Time, StreamingContext}
+import scala.reflect.ClassTag
private[streaming]
-class QueueInputDStream[T: ClassManifest](
+class QueueInputDStream[T: ClassTag](
@transient ssc: StreamingContext,
val queue: Queue[RDD[T]],
oneAtATime: Boolean,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
index 10ed4ef78d..dea0f26f90 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
@@ -21,6 +21,8 @@ import org.apache.spark.Logging
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.StreamingContext
+import scala.reflect.ClassTag
+
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, SocketChannel}
@@ -35,7 +37,7 @@ import java.util.concurrent.ArrayBlockingQueue
* in the format that the system is configured with.
*/
private[streaming]
-class RawInputDStream[T: ClassManifest](
+class RawInputDStream[T: ClassTag](
@transient ssc_ : StreamingContext,
host: String,
port: Int,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
index b88a4db959..db56345ca8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala
@@ -28,8 +28,11 @@ import org.apache.spark.storage.StorageLevel
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
private[streaming]
-class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
+class ReducedWindowedDStream[K: ClassTag, V: ClassTag](
parent: DStream[(K, V)],
reduceFunc: (V, V) => V,
invReduceFunc: (V, V) => V,
@@ -49,7 +52,7 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
"must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")"
)
- // Reduce each batch of data using reduceByKey which will be further reduced by window
+ // Reduce each batch of data using reduceByKey which will be further reduced by window
// by ReducedWindowedDStream
val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
@@ -170,5 +173,3 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
}
}
}
-
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index a95e66d761..e6e0022097 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -21,9 +21,10 @@ import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
import org.apache.spark.streaming.{Duration, DStream, Time}
+import scala.reflect.ClassTag
private[streaming]
-class ShuffledDStream[K: ClassManifest, V: ClassManifest, C: ClassManifest](
+class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
parent: DStream[(K,V)],
createCombiner: V => C,
mergeValue: (C, V) => C,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
index e2539c7396..2cdd13f205 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
@@ -21,11 +21,13 @@ import org.apache.spark.streaming.StreamingContext
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.NextIterator
+import scala.reflect.ClassTag
+
import java.io._
import java.net.Socket
private[streaming]
-class SocketInputDStream[T: ClassManifest](
+class SocketInputDStream[T: ClassTag](
@transient ssc_ : StreamingContext,
host: String,
port: Int,
@@ -39,7 +41,7 @@ class SocketInputDStream[T: ClassManifest](
}
private[streaming]
-class SocketReceiver[T: ClassManifest](
+class SocketReceiver[T: ClassTag](
host: String,
port: Int,
bytesToObjects: InputStream => Iterator[T],
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 362a6bf4cc..e0ff3ccba4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -23,8 +23,10 @@ import org.apache.spark.SparkContext._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Duration, Time, DStream}
+import scala.reflect.ClassTag
+
private[streaming]
-class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest](
+class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
index 60485adef9..aeea060df7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TransformedDStream.scala
@@ -19,18 +19,25 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, DStream, Time}
+import scala.reflect.ClassTag
private[streaming]
-class TransformedDStream[T: ClassManifest, U: ClassManifest] (
- parent: DStream[T],
- transformFunc: (RDD[T], Time) => RDD[U]
- ) extends DStream[U](parent.ssc) {
+class TransformedDStream[U: ClassTag] (
+ parents: Seq[DStream[_]],
+ transformFunc: (Seq[RDD[_]], Time) => RDD[U]
+ ) extends DStream[U](parents.head.ssc) {
- override def dependencies = List(parent)
+ require(parents.length > 0, "List of DStreams to transform is empty")
+ require(parents.map(_.ssc).distinct.size == 1, "Some of the DStreams have different contexts")
+ require(parents.map(_.slideDuration).distinct.size == 1,
+ "Some of the DStreams have different slide durations")
- override def slideDuration: Duration = parent.slideDuration
+ override def dependencies = parents.toList
+
+ override def slideDuration: Duration = parents.head.slideDuration
override def compute(validTime: Time): Option[RDD[U]] = {
- parent.getOrCompute(validTime).map(transformFunc(_, validTime))
+ val parentRDDs = parents.map(_.getOrCompute(validTime).orNull).toSeq
+ Some(transformFunc(parentRDDs, validTime))
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
index c696bb70a8..0d84ec84f2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala
@@ -22,8 +22,11 @@ import org.apache.spark.rdd.RDD
import collection.mutable.ArrayBuffer
import org.apache.spark.rdd.UnionRDD
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
private[streaming]
-class UnionDStream[T: ClassManifest](parents: Array[DStream[T]])
+class UnionDStream[T: ClassTag](parents: Array[DStream[T]])
extends DStream[T](parents.head.ssc) {
if (parents.length == 0) {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 3c57294269..73d959331a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -22,8 +22,10 @@ import org.apache.spark.rdd.UnionRDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import scala.reflect.ClassTag
+
private[streaming]
-class WindowedDStream[T: ClassManifest](
+class WindowedDStream[T: ClassTag](
parent: DStream[T],
_windowDuration: Duration,
_slideDuration: Duration)
@@ -52,6 +54,3 @@ class WindowedDStream[T: ClassManifest](
Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
}
}
-
-
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index ef0f85a717..fdf5371a89 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -20,6 +20,10 @@ package org.apache.spark.streaming.receivers
import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy }
import akka.actor.{ actorRef2Scala, ActorRef }
import akka.actor.{ PossiblyHarmful, OneForOneStrategy }
+import akka.actor.SupervisorStrategy._
+
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.dstream.NetworkReceiver
@@ -28,12 +32,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.ArrayBuffer
-/** A helper with set of defaults for supervisor strategy **/
+/** A helper with set of defaults for supervisor strategy */
object ReceiverSupervisorStrategy {
- import akka.util.duration._
- import akka.actor.SupervisorStrategy._
-
val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange =
15 millis) {
case _: RuntimeException ⇒ Restart
@@ -48,10 +49,10 @@ object ReceiverSupervisorStrategy {
* Find more details at: http://spark-project.org/docs/latest/streaming-custom-receivers.html
*
* @example {{{
- * class MyActor extends Actor with Receiver{
- * def receive {
- * case anything :String ⇒ pushBlock(anything)
- * }
+ * class MyActor extends Actor with Receiver{
+ * def receive {
+ * case anything :String => pushBlock(anything)
+ * }
* }
* //Can be plugged in actorStream as follows
* ssc.actorStream[String](Props(new MyActor),"MyActorReceiver")
@@ -65,11 +66,11 @@ object ReceiverSupervisorStrategy {
*
*/
trait Receiver { self: Actor ⇒
- def pushBlock[T: ClassManifest](iter: Iterator[T]) {
+ def pushBlock[T: ClassTag](iter: Iterator[T]) {
context.parent ! Data(iter)
}
- def pushBlock[T: ClassManifest](data: T) {
+ def pushBlock[T: ClassTag](data: T) {
context.parent ! Data(data)
}
@@ -83,8 +84,8 @@ case class Statistics(numberOfMsgs: Int,
numberOfHiccups: Int,
otherInfo: String)
-/** Case class to receive data sent by child actors **/
-private[streaming] case class Data[T: ClassManifest](data: T)
+/** Case class to receive data sent by child actors */
+private[streaming] case class Data[T: ClassTag](data: T)
/**
* Provides Actors as receivers for receiving stream.
@@ -95,19 +96,19 @@ private[streaming] case class Data[T: ClassManifest](data: T)
* his own Actor to run as receiver for Spark Streaming input source.
*
* This starts a supervisor actor which starts workers and also provides
- * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance].
- *
+ * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance].
+ *
* Here's a way to start more supervisor/workers as its children.
*
* @example {{{
- * context.parent ! Props(new Supervisor)
+ * context.parent ! Props(new Supervisor)
* }}} OR {{{
* context.parent ! Props(new Worker,"Worker")
* }}}
*
*
*/
-private[streaming] class ActorReceiver[T: ClassManifest](
+private[streaming] class ActorReceiver[T: ClassTag](
props: Props,
name: String,
storageLevel: StorageLevel,
@@ -120,7 +121,7 @@ private[streaming] class ActorReceiver[T: ClassManifest](
protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor),
"Supervisor" + streamId)
- private class Supervisor extends Actor {
+ class Supervisor extends Actor {
override val supervisorStrategy = receiverSupervisorStrategy
val worker = context.actorOf(props, name)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala
index 043bb8c8bf..f164d516b0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ZeroMQReceiver.scala
@@ -17,7 +17,10 @@
package org.apache.spark.streaming.receivers
+import scala.reflect.ClassTag
+
import akka.actor.Actor
+import akka.util.ByteString
import akka.zeromq._
import org.apache.spark.Logging
@@ -25,12 +28,12 @@ import org.apache.spark.Logging
/**
* A receiver to subscribe to ZeroMQ stream.
*/
-private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String,
+private[streaming] class ZeroMQReceiver[T: ClassTag](publisherUrl: String,
subscribe: Subscribe,
- bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T])
+ bytesToObjects: Seq[ByteString] ⇒ Iterator[T])
extends Actor with Receiver with Logging {
- override def preStart() = context.system.newSocket(SocketType.Sub, Listener(self),
+ override def preStart() = ZeroMQExtension(context.system).newSocket(SocketType.Sub, Listener(self),
Connect(publisherUrl), subscribe)
def receive: Receive = {
@@ -38,10 +41,10 @@ private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String,
case Connecting ⇒ logInfo("connecting ...")
case m: ZMQMessage ⇒
- logDebug("Received message for:" + m.firstFrameAsString)
+ logDebug("Received message for:" + m.frame(0))
//We ignore first frame for processing as it is the topic
- val bytes = m.frames.tail.map(_.payload)
+ val bytes = m.frames.tail
pushBlock(bytesToObjects(bytes))
case Closed ⇒ logInfo("received closed ")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
new file mode 100644
index 0000000000..4e8d07fe92
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.streaming.Time
+
+/**
+ * Class having information on completed batches.
+ * @param batchTime Time of the batch
+ * @param submissionTime Clock time of when jobs of this batch was submitted to
+ * the streaming scheduler queue
+ * @param processingStartTime Clock time of when the first job of this batch started processing
+ * @param processingEndTime Clock time of when the last job of this batch finished processing
+ */
+case class BatchInfo(
+ batchTime: Time,
+ submissionTime: Long,
+ processingStartTime: Option[Long],
+ processingEndTime: Option[Long]
+ ) {
+
+ /**
+ * Time taken for the first job of this batch to start processing from the time this batch
+ * was submitted to the streaming scheduler. Essentially, it is
+ * `processingStartTime` - `submissionTime`.
+ */
+ def schedulingDelay = processingStartTime.map(_ - submissionTime)
+
+ /**
+ * Time taken for the all jobs of this batch to finish processing from the time they started
+ * processing. Essentially, it is `processingEndTime` - `processingStartTime`.
+ */
+ def processingDelay = processingEndTime.zip(processingStartTime).map(x => x._1 - x._2).headOption
+
+ /**
+ * Time taken for all the jobs of this batch to finish processing from the time they
+ * were submitted. Essentially, it is `processingDelay` + `schedulingDelay`.
+ */
+ def totalDelay = schedulingDelay.zip(processingDelay).map(x => x._1 + x._2).headOption
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
index 2128b7c7a6..7341bfbc99 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Job.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala
@@ -15,13 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
-import java.util.concurrent.atomic.AtomicLong
+import org.apache.spark.streaming.Time
+/**
+ * Class representing a Spark computation. It may contain multiple Spark jobs.
+ */
private[streaming]
class Job(val time: Time, func: () => _) {
- val id = Job.getNewId()
+ var id: String = _
+
def run(): Long = {
val startTime = System.currentTimeMillis
func()
@@ -29,13 +33,9 @@ class Job(val time: Time, func: () => _) {
(stopTime - startTime)
}
- override def toString = "streaming job " + id + " @ " + time
-}
-
-private[streaming]
-object Job {
- val id = new AtomicLong(0)
-
- def getNewId() = id.getAndIncrement()
-}
+ def setId(number: Int) {
+ id = "streaming job " + time + "." + number
+ }
+ override def toString = id
+} \ No newline at end of file
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index ed892e33e6..1cd0b9b0a4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -15,31 +15,34 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
-import util.{ManualClock, RecurringTimer, Clock}
import org.apache.spark.SparkEnv
import org.apache.spark.Logging
+import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
+import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
+/**
+ * This class generates jobs from DStreams as well as drives checkpointing and cleaning
+ * up DStream metadata.
+ */
private[streaming]
-class Scheduler(ssc: StreamingContext) extends Logging {
+class JobGenerator(jobScheduler: JobScheduler) extends Logging {
initLogging()
-
- val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
- val jobManager = new JobManager(ssc, concurrentJobs)
- val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(ssc.checkpointDir)
- } else {
- null
- }
-
+ val ssc = jobScheduler.ssc
val clockClass = System.getProperty(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
longTime => generateJobs(new Time(longTime)))
val graph = ssc.graph
+ lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
+ new CheckpointWriter(ssc.checkpointDir)
+ } else {
+ null
+ }
+
var latestTime: Time = null
def start() = synchronized {
@@ -48,26 +51,24 @@ class Scheduler(ssc: StreamingContext) extends Logging {
} else {
startFirstTime()
}
- logInfo("Scheduler started")
+ logInfo("JobGenerator started")
}
def stop() = synchronized {
timer.stop()
- jobManager.stop()
if (checkpointWriter != null) checkpointWriter.stop()
ssc.graph.stop()
- logInfo("Scheduler stopped")
+ logInfo("JobGenerator stopped")
}
private def startFirstTime() {
val startTime = new Time(timer.getStartTime())
graph.start(startTime - graph.batchDuration)
timer.start(startTime.milliseconds)
- logInfo("Scheduler's timer started at " + startTime)
+ logInfo("JobGenerator's timer started at " + startTime)
}
private def restart() {
-
// If manual clock is being used for testing, then
// either set the manual clock to the last checkpointed time,
// or if the property is defined set it to that time
@@ -93,35 +94,34 @@ class Scheduler(ssc: StreamingContext) extends Logging {
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
- graph.generateJobs(time).foreach(jobManager.runJob)
+ jobScheduler.runJobs(time, graph.generateJobs(time))
)
// Restart the timer
timer.start(restartTime.milliseconds)
- logInfo("Scheduler's timer restarted at " + restartTime)
+ logInfo("JobGenerator's timer restarted at " + restartTime)
}
/** Generate jobs and perform checkpoint for the given `time`. */
- def generateJobs(time: Time) {
+ private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
logInfo("\n-----------------------------------------------------\n")
- graph.generateJobs(time).foreach(jobManager.runJob)
+ jobScheduler.runJobs(time, graph.generateJobs(time))
latestTime = time
doCheckpoint(time)
}
/**
- * Clear old metadata assuming jobs of `time` have finished processing.
- * And also perform checkpoint.
+ * On batch completion, clear old metadata and checkpoint computation.
*/
- def clearOldMetadata(time: Time) {
+ private[streaming] def onBatchCompletion(time: Time) {
ssc.graph.clearOldMetadata(time)
doCheckpoint(time)
}
/** Perform checkpoint for the give `time`. */
- def doCheckpoint(time: Time) = synchronized {
- if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+ private def doCheckpoint(time: Time) = synchronized {
+ if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
new file mode 100644
index 0000000000..9511ccfbed
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkEnv
+import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming._
+
+/**
+ * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate
+ * the jobs and runs them using a thread pool. Number of threads
+ */
+private[streaming]
+class JobScheduler(val ssc: StreamingContext) extends Logging {
+
+ initLogging()
+
+ val jobSets = new ConcurrentHashMap[Time, JobSet]
+ val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
+ val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ val generator = new JobGenerator(this)
+ val listenerBus = new StreamingListenerBus()
+
+ def clock = generator.clock
+
+ def start() {
+ generator.start()
+ }
+
+ def stop() {
+ generator.stop()
+ executor.shutdown()
+ if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
+ executor.shutdownNow()
+ }
+ }
+
+ def runJobs(time: Time, jobs: Seq[Job]) {
+ if (jobs.isEmpty) {
+ logInfo("No jobs added for time " + time)
+ } else {
+ val jobSet = new JobSet(time, jobs)
+ jobSets.put(time, jobSet)
+ jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+ logInfo("Added jobs for time " + time)
+ }
+ }
+
+ def getPendingTimes(): Array[Time] = {
+ jobSets.keySet.toArray(new Array[Time](0))
+ }
+
+ private def beforeJobStart(job: Job) {
+ val jobSet = jobSets.get(job.time)
+ if (!jobSet.hasStarted) {
+ listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
+ }
+ jobSet.beforeJobStart(job)
+ logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
+ SparkEnv.set(generator.ssc.env)
+ }
+
+ private def afterJobEnd(job: Job) {
+ val jobSet = jobSets.get(job.time)
+ jobSet.afterJobStop(job)
+ logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
+ if (jobSet.hasCompleted) {
+ jobSets.remove(jobSet.time)
+ generator.onBatchCompletion(jobSet.time)
+ logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
+ jobSet.totalDelay / 1000.0, jobSet.time.toString,
+ jobSet.processingDelay / 1000.0
+ ))
+ listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
+ }
+ }
+
+ private[streaming]
+ class JobHandler(job: Job) extends Runnable {
+ def run() {
+ beforeJobStart(job)
+ try {
+ job.run()
+ } catch {
+ case e: Exception =>
+ logError("Running " + job + " failed", e)
+ }
+ afterJobEnd(job)
+ }
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
new file mode 100644
index 0000000000..57268674ea
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.HashSet
+import org.apache.spark.streaming.Time
+
+/** Class representing a set of Jobs
+ * belong to the same batch.
+ */
+private[streaming]
+case class JobSet(time: Time, jobs: Seq[Job]) {
+
+ private val incompleteJobs = new HashSet[Job]()
+ var submissionTime = System.currentTimeMillis() // when this jobset was submitted
+ var processingStartTime = -1L // when the first job of this jobset started processing
+ var processingEndTime = -1L // when the last job of this jobset finished processing
+
+ jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
+ incompleteJobs ++= jobs
+
+ def beforeJobStart(job: Job) {
+ if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
+ }
+
+ def afterJobStop(job: Job) {
+ incompleteJobs -= job
+ if (hasCompleted) processingEndTime = System.currentTimeMillis()
+ }
+
+ def hasStarted() = (processingStartTime > 0)
+
+ def hasCompleted() = incompleteJobs.isEmpty
+
+ // Time taken to process all the jobs from the time they started processing
+ // (i.e. not including the time they wait in the streaming scheduler queue)
+ def processingDelay = processingEndTime - processingStartTime
+
+ // Time taken to process all the jobs from the time they were submitted
+ // (i.e. including the time they wait in the streaming scheduler queue)
+ def totalDelay = {
+ processingEndTime - time.milliseconds
+ }
+
+ def toBatchInfo(): BatchInfo = {
+ new BatchInfo(
+ time,
+ submissionTime,
+ if (processingStartTime >= 0 ) Some(processingStartTime) else None,
+ if (processingEndTime >= 0 ) Some(processingEndTime) else None
+ )
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index b97fb7e6e3..abff55d77c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.streaming
+package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
@@ -25,12 +25,13 @@ import org.apache.spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.Queue
+import scala.concurrent.duration._
import akka.actor._
import akka.pattern.ask
-import akka.util.duration._
import akka.dispatch._
import org.apache.spark.storage.BlockId
+import org.apache.spark.streaming.{Time, StreamingContext}
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
new file mode 100644
index 0000000000..36225e190c
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import scala.collection.mutable.Queue
+import org.apache.spark.util.Distribution
+
+/** Base trait for events related to StreamingListener */
+sealed trait StreamingListenerEvent
+
+case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
+
+
+/**
+ * A listener interface for receiving information about an ongoing streaming
+ * computation.
+ */
+trait StreamingListener {
+ /**
+ * Called when processing of a batch has completed
+ */
+ def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { }
+
+ /**
+ * Called when processing of a batch has started
+ */
+ def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { }
+}
+
+
+/**
+ * A simple StreamingListener that logs summary statistics across Spark Streaming batches
+ * @param numBatchInfos Number of last batches to consider for generating statistics (default: 10)
+ */
+class StatsReportListener(numBatchInfos: Int = 10) extends StreamingListener {
+ // Queue containing latest completed batches
+ val batchInfos = new Queue[BatchInfo]()
+
+ override def onBatchCompleted(batchStarted: StreamingListenerBatchCompleted) {
+ batchInfos.enqueue(batchStarted.batchInfo)
+ if (batchInfos.size > numBatchInfos) batchInfos.dequeue()
+ printStats()
+ }
+
+ def printStats() {
+ showMillisDistribution("Total delay: ", _.totalDelay)
+ showMillisDistribution("Processing time: ", _.processingDelay)
+ }
+
+ def showMillisDistribution(heading: String, getMetric: BatchInfo => Option[Long]) {
+ org.apache.spark.scheduler.StatsReportListener.showMillisDistribution(
+ heading, extractDistribution(getMetric))
+ }
+
+ def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = {
+ Distribution(batchInfos.flatMap(getMetric(_)).map(_.toDouble))
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
new file mode 100644
index 0000000000..110a20f282
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import org.apache.spark.Logging
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import java.util.concurrent.LinkedBlockingQueue
+
+/** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */
+private[spark] class StreamingListenerBus() extends Logging {
+ private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener]
+
+ /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
+ * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
+ private val EVENT_QUEUE_CAPACITY = 10000
+ private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
+ private var queueFullErrorMessageLogged = false
+
+ new Thread("StreamingListenerBus") {
+ setDaemon(true)
+ override def run() {
+ while (true) {
+ val event = eventQueue.take
+ event match {
+ case batchStarted: StreamingListenerBatchStarted =>
+ listeners.foreach(_.onBatchStarted(batchStarted))
+ case batchCompleted: StreamingListenerBatchCompleted =>
+ listeners.foreach(_.onBatchCompleted(batchCompleted))
+ case _ =>
+ }
+ }
+ }
+ }.start()
+
+ def addListener(listener: StreamingListener) {
+ listeners += listener
+ }
+
+ def post(event: StreamingListenerEvent) {
+ val eventAdded = eventQueue.offer(event)
+ if (!eventAdded && !queueFullErrorMessageLogged) {
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
+ "rate at which tasks are being started by the scheduler.")
+ queueFullErrorMessageLogged = true
+ }
+ }
+
+ /**
+ * Waits until there are no more events in the queue, or until the specified time has elapsed.
+ * Used for testing only. Returns true if the queue has emptied and false is the specified time
+ * elapsed before the queue emptied.
+ */
+ def waitUntilEmpty(timeoutMillis: Int): Boolean = {
+ val finishTime = System.currentTimeMillis + timeoutMillis
+ while (!eventQueue.isEmpty()) {
+ if (System.currentTimeMillis > finishTime) {
+ return false
+ }
+ /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
+ * add overhead in the general case. */
+ Thread.sleep(10)
+ }
+ return true
+ }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
index 6977957126..4a3993e3e3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala
@@ -25,6 +25,7 @@ import StreamingContext._
import scala.util.Random
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.reflect.ClassTag
import java.io.{File, ObjectInputStream, IOException}
import java.util.UUID
@@ -120,7 +121,7 @@ object MasterFailureTest extends Logging {
* Tests stream operation with multiple master failures, and verifies whether the
* final set of output values is as expected or not.
*/
- def testOperation[T: ClassManifest](
+ def testOperation[T: ClassTag](
directory: String,
batchDuration: Duration,
input: Seq[String],
@@ -158,7 +159,7 @@ object MasterFailureTest extends Logging {
* and batch duration. Returns the streaming context and the directory to which
* files should be written for testing.
*/
- private def setupStreams[T: ClassManifest](
+ private def setupStreams[T: ClassTag](
directory: String,
batchDuration: Duration,
operation: DStream[String] => DStream[T]
@@ -192,7 +193,7 @@ object MasterFailureTest extends Logging {
* Repeatedly starts and kills the streaming context until timed out or
* the last expected output is generated. Finally, return
*/
- private def runStreams[T: ClassManifest](
+ private def runStreams[T: ClassTag](
ssc_ : StreamingContext,
lastExpectedOutput: T,
maxTimeToRun: Long
@@ -274,7 +275,7 @@ object MasterFailureTest extends Logging {
* duplicate batch outputs of values from the `output`. As a result, the
* expected output should not have consecutive batches with the same values as output.
*/
- private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) {
+ private def verifyOutput[T: ClassTag](output: Seq[T], expectedOutput: Seq[T]) {
// Verify whether expected outputs do not consecutive batches with same output
for (i <- 0 until expectedOutput.size - 1) {
assert(expectedOutput(i) != expectedOutput(i+1),
@@ -305,7 +306,7 @@ object MasterFailureTest extends Logging {
* ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
*/
private[streaming]
-class TestOutputStream[T: ClassManifest](
+class TestOutputStream[T: ClassTag](
parent: DStream[T],
val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
) extends ForEachDStream[T](
@@ -380,24 +381,24 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
val tempHadoopFile = new Path(testDir, ".tmp_" + (i+1).toString)
FileUtils.writeStringToFile(localFile, input(i).toString + "\n")
var tries = 0
- var done = false
- while (!done && tries < maxTries) {
- tries += 1
- try {
- // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
- fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile)
- fs.rename(tempHadoopFile, hadoopFile)
- done = true
- } catch {
- case ioe: IOException => {
- fs = testDir.getFileSystem(new Configuration())
- logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe)
- }
- }
+ var done = false
+ while (!done && tries < maxTries) {
+ tries += 1
+ try {
+ // fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile)
+ fs.copyFromLocalFile(new Path(localFile.toString), tempHadoopFile)
+ fs.rename(tempHadoopFile, hadoopFile)
+ done = true
+ } catch {
+ case ioe: IOException => {
+ fs = testDir.getFileSystem(new Configuration())
+ logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe)
+ }
+ }
}
- if (!done)
+ if (!done)
logError("Could not generate file " + hadoopFile)
- else
+ else
logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis)
Thread.sleep(interval)
localFile.delete()
@@ -411,5 +412,3 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long)
}
}
}
-
-
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index c0d729ff87..daeb99f5b7 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -21,27 +21,31 @@ import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
+
import kafka.serializer.StringDecoder;
+
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.spark.streaming.api.java.JavaDStreamLike;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+
import scala.Tuple2;
+import twitter4j.Status;
+
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaRDDLike;
-import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.*;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.dstream.SparkFlumeEvent;
import org.apache.spark.streaming.JavaTestUtils;
import org.apache.spark.streaming.JavaCheckpointTestUtils;
-import org.apache.spark.streaming.InputStreamsSuite;
import java.io.*;
import java.util.*;
@@ -50,7 +54,6 @@ import akka.actor.Props;
import akka.zeromq.Subscribe;
-
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
@@ -85,8 +88,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(3L),
Arrays.asList(1L));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream count = stream.count();
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Long> count = stream.count();
JavaTestUtils.attachTestOutputStream(count);
List<List<Long>> result = JavaTestUtils.runStreams(ssc, 3, 3);
assertOrderInvariantEquals(expected, result);
@@ -102,8 +105,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(5,5),
Arrays.asList(9,4));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream letterCount = stream.map(new Function<String, Integer>() {
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> letterCount = stream.map(new Function<String, Integer>() {
@Override
public Integer call(String s) throws Exception {
return s.length();
@@ -128,8 +131,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(7,8,9,4,5,6),
Arrays.asList(7,8,9));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream windowed = stream.window(new Duration(2000));
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> windowed = stream.window(new Duration(2000));
JavaTestUtils.attachTestOutputStream(windowed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 4, 4);
@@ -152,8 +155,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(7,8,9,10,11,12,13,14,15,16,17,18),
Arrays.asList(13,14,15,16,17,18));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream windowed = stream.window(new Duration(4000), new Duration(2000));
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> windowed = stream.window(new Duration(4000), new Duration(2000));
JavaTestUtils.attachTestOutputStream(windowed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 8, 4);
@@ -170,8 +173,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList("giants"),
Arrays.asList("yankees"));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream filtered = stream.filter(new Function<String, Boolean>() {
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<String> filtered = stream.filter(new Function<String, Boolean>() {
@Override
public Boolean call(String s) throws Exception {
return s.contains("a");
@@ -184,6 +187,39 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testRepartitionMorePartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
+ JavaDStream repartitioned = stream.repartition(4);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(4, rdd.size());
+ Assert.assertEquals(
+ 10, rdd.get(0).size() + rdd.get(1).size() + rdd.get(2).size() + rdd.get(3).size());
+ }
+ }
+
+ @Test
+ public void testRepartitionFewerPartitions() {
+ List<List<Integer>> inputData = Arrays.asList(
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
+ Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
+ JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
+ JavaDStream repartitioned = stream.repartition(2);
+ JavaTestUtils.attachTestOutputStream(repartitioned);
+ List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
+ Assert.assertEquals(2, result.size());
+ for (List<List<Integer>> rdd : result) {
+ Assert.assertEquals(2, rdd.size());
+ Assert.assertEquals(10, rdd.get(0).size() + rdd.get(1).size());
+ }
+ }
+
+ @Test
public void testGlom() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("giants", "dodgers"),
@@ -193,8 +229,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(Arrays.asList("giants", "dodgers")),
Arrays.asList(Arrays.asList("yankees", "red socks")));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream glommed = stream.glom();
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<List<String>> glommed = stream.glom();
JavaTestUtils.attachTestOutputStream(glommed);
List<List<List<String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
@@ -211,8 +247,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList("GIANTSDODGERS"),
Arrays.asList("YANKEESRED SOCKS"));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream mapped = stream.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<String> mapped = stream.mapPartitions(new FlatMapFunction<Iterator<String>, String>() {
@Override
public Iterable<String> call(Iterator<String> in) {
String out = "";
@@ -223,7 +259,7 @@ public class JavaAPISuite implements Serializable {
}
});
JavaTestUtils.attachTestOutputStream(mapped);
- List<List<List<String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<String>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -254,8 +290,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(15),
Arrays.asList(24));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream reduced = stream.reduce(new IntegerSum());
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> reduced = stream.reduce(new IntegerSum());
JavaTestUtils.attachTestOutputStream(reduced);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -275,8 +311,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(39),
Arrays.asList(24));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> reducedWindowed = stream.reduceByWindow(new IntegerSum(),
new IntegerDifference(), new Duration(2000), new Duration(1000));
JavaTestUtils.attachTestOutputStream(reducedWindowed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 4, 4);
@@ -292,8 +328,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(7,8,9));
JavaSparkContext jsc = new JavaSparkContext(ssc.ssc().sc());
- JavaRDD<Integer> rdd1 = ssc.sc().parallelize(Arrays.asList(1,2,3));
- JavaRDD<Integer> rdd2 = ssc.sc().parallelize(Arrays.asList(4,5,6));
+ JavaRDD<Integer> rdd1 = ssc.sc().parallelize(Arrays.asList(1, 2, 3));
+ JavaRDD<Integer> rdd2 = ssc.sc().parallelize(Arrays.asList(4, 5, 6));
JavaRDD<Integer> rdd3 = ssc.sc().parallelize(Arrays.asList(7,8,9));
LinkedList<JavaRDD<Integer>> rdds = Lists.newLinkedList();
@@ -320,17 +356,19 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(9,10,11));
JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream<Integer> transformed =
- stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
- @Override
- public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
- return in.map(new Function<Integer, Integer>() {
- @Override
- public Integer call(Integer i) throws Exception {
- return i + 2;
- }
- });
- }});
+ JavaDStream<Integer> transformed = stream.transform(
+ new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
+ return in.map(new Function<Integer, Integer>() {
+ @Override
+ public Integer call(Integer i) throws Exception {
+ return i + 2;
+ }
+ });
+ }
+ });
+
JavaTestUtils.attachTestOutputStream(transformed);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -338,6 +376,316 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void testVariousTransform() {
+ // tests whether all variations of transform can be called from Java
+
+ List<List<Integer>> inputData = Arrays.asList(Arrays.asList(1));
+ JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+
+ List<List<Tuple2<String, Integer>>> pairInputData =
+ Arrays.asList(Arrays.asList(new Tuple2<String, Integer>("x", 1)));
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData, 1));
+
+ JavaDStream<Integer> transformed1 = stream.transform(
+ new Function<JavaRDD<Integer>, JavaRDD<Integer>>() {
+ @Override
+ public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> transformed2 = stream.transform(
+ new Function2<JavaRDD<Integer>, Time, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaRDD<Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, Integer> transformed3 = stream.transform(
+ new Function<JavaRDD<Integer>, JavaPairRDD<String, Integer>>() {
+ @Override public JavaPairRDD<String, Integer> call(JavaRDD<Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, Integer> transformed4 = stream.transform(
+ new Function2<JavaRDD<Integer>, Time, JavaPairRDD<String, Integer>>() {
+ @Override public JavaPairRDD<String, Integer> call(JavaRDD<Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> pairTransformed1 = pairStream.transform(
+ new Function<JavaPairRDD<String, Integer>, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaPairRDD<String, Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Integer> pairTransformed2 = pairStream.transform(
+ new Function2<JavaPairRDD<String, Integer>, Time, JavaRDD<Integer>>() {
+ @Override public JavaRDD<Integer> call(JavaPairRDD<String, Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, String> pairTransformed3 = pairStream.transform(
+ new Function<JavaPairRDD<String, Integer>, JavaPairRDD<String, String>>() {
+ @Override public JavaPairRDD<String, String> call(JavaPairRDD<String, Integer> in) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<String, String> pairTransformed4 = pairStream.transform(
+ new Function2<JavaPairRDD<String, Integer>, Time, JavaPairRDD<String, String>>() {
+ @Override public JavaPairRDD<String, String> call(JavaPairRDD<String, Integer> in, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ }
+
+ @Test
+ public void testTransformWith() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(
+ new Tuple2<String, String>("california", "sharks"),
+ new Tuple2<String, String>("new york", "rangers")));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, String>("california", "giants"),
+ new Tuple2<String, String>("new york", "mets")),
+ Arrays.asList(
+ new Tuple2<String, String>("california", "ducks"),
+ new Tuple2<String, String>("new york", "islanders")));
+
+
+ List<List<Tuple2<String, Tuple2<String, String>>>> expected = Arrays.asList(
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("dodgers", "giants")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("yankees", "mets"))),
+ Arrays.asList(
+ new Tuple2<String, Tuple2<String, String>>("california",
+ new Tuple2<String, String>("sharks", "ducks")),
+ new Tuple2<String, Tuple2<String, String>>("new york",
+ new Tuple2<String, String>("rangers", "islanders"))));
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<String, String>> joined = pairStream1.transformWith(
+ pairStream2,
+ new Function3<
+ JavaPairRDD<String, String>,
+ JavaPairRDD<String, String>,
+ Time,
+ JavaPairRDD<String, Tuple2<String, String>>
+ >() {
+ @Override
+ public JavaPairRDD<String, Tuple2<String, String>> call(
+ JavaPairRDD<String, String> rdd1,
+ JavaPairRDD<String, String> rdd2,
+ Time time
+ ) throws Exception {
+ return rdd1.join(rdd2);
+ }
+ }
+ );
+
+ JavaTestUtils.attachTestOutputStream(joined);
+ List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+
+ @Test
+ public void testVariousTransformWith() {
+ // tests whether all variations of transformWith can be called from Java
+
+ List<List<Integer>> inputData1 = Arrays.asList(Arrays.asList(1));
+ List<List<String>> inputData2 = Arrays.asList(Arrays.asList("x"));
+ JavaDStream<Integer> stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 1);
+ JavaDStream<String> stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 1);
+
+ List<List<Tuple2<String, Integer>>> pairInputData1 =
+ Arrays.asList(Arrays.asList(new Tuple2<String, Integer>("x", 1)));
+ List<List<Tuple2<Double, Character>>> pairInputData2 =
+ Arrays.asList(Arrays.asList(new Tuple2<Double, Character>(1.0, 'x')));
+ JavaPairDStream<String, Integer> pairStream1 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData1, 1));
+ JavaPairDStream<Double, Character> pairStream2 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairInputData2, 1));
+
+ JavaDStream<Double> transformed1 = stream1.transformWith(
+ stream2,
+ new Function3<JavaRDD<Integer>, JavaRDD<String>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaRDD<Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> transformed2 = stream1.transformWith(
+ pairStream1,
+ new Function3<JavaRDD<Integer>, JavaPairRDD<String, Integer>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaRDD<Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> transformed3 = stream1.transformWith(
+ stream2,
+ new Function3<JavaRDD<Integer>, JavaRDD<String>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaRDD<Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> transformed4 = stream1.transformWith(
+ pairStream1,
+ new Function3<JavaRDD<Integer>, JavaPairRDD<String, Integer>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaRDD<Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> pairTransformed1 = pairStream1.transformWith(
+ stream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaRDD<String>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaPairRDD<String, Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaDStream<Double> pairTransformed2_ = pairStream1.transformWith(
+ pairStream1,
+ new Function3<JavaPairRDD<String, Integer>, JavaPairRDD<String, Integer>, Time, JavaRDD<Double>>() {
+ @Override
+ public JavaRDD<Double> call(JavaPairRDD<String, Integer> rdd1, JavaPairRDD<String, Integer> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> pairTransformed3 = pairStream1.transformWith(
+ stream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaRDD<String>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaPairRDD<String, Integer> rdd1, JavaRDD<String> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+
+ JavaPairDStream<Double, Double> pairTransformed4 = pairStream1.transformWith(
+ pairStream2,
+ new Function3<JavaPairRDD<String, Integer>, JavaPairRDD<Double, Character>, Time, JavaPairRDD<Double, Double>>() {
+ @Override
+ public JavaPairRDD<Double, Double> call(JavaPairRDD<String, Integer> rdd1, JavaPairRDD<Double, Character> rdd2, Time time) throws Exception {
+ return null;
+ }
+ }
+ );
+ }
+
+ @Test
+ public void testStreamingContextTransform(){
+ List<List<Integer>> stream1input = Arrays.asList(
+ Arrays.asList(1),
+ Arrays.asList(2)
+ );
+
+ List<List<Integer>> stream2input = Arrays.asList(
+ Arrays.asList(3),
+ Arrays.asList(4)
+ );
+
+ List<List<Tuple2<Integer, String>>> pairStream1input = Arrays.asList(
+ Arrays.asList(new Tuple2<Integer, String>(1, "x")),
+ Arrays.asList(new Tuple2<Integer, String>(2, "y"))
+ );
+
+ List<List<Tuple2<Integer, Tuple2<Integer, String>>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<Integer, Tuple2<Integer, String>>(1, new Tuple2<Integer, String>(1, "x"))),
+ Arrays.asList(new Tuple2<Integer, Tuple2<Integer, String>>(2, new Tuple2<Integer, String>(2, "y")))
+ );
+
+ JavaDStream<Integer> stream1 = JavaTestUtils.attachTestInputStream(ssc, stream1input, 1);
+ JavaDStream<Integer> stream2 = JavaTestUtils.attachTestInputStream(ssc, stream2input, 1);
+ JavaPairDStream<Integer, String> pairStream1 = JavaPairDStream.fromJavaDStream(
+ JavaTestUtils.attachTestInputStream(ssc, pairStream1input, 1));
+
+ List<JavaDStream<?>> listOfDStreams1 = Arrays.<JavaDStream<?>>asList(stream1, stream2);
+
+ // This is just to test whether this transform to JavaStream compiles
+ JavaDStream<Long> transformed1 = ssc.transform(
+ listOfDStreams1,
+ new Function2<List<JavaRDD<?>>, Time, JavaRDD<Long>>() {
+ public JavaRDD<Long> call(List<JavaRDD<?>> listOfRDDs, Time time) {
+ assert(listOfRDDs.size() == 2);
+ return null;
+ }
+ }
+ );
+
+ List<JavaDStream<?>> listOfDStreams2 =
+ Arrays.<JavaDStream<?>>asList(stream1, stream2, pairStream1.toJavaDStream());
+
+ JavaPairDStream<Integer, Tuple2<Integer, String>> transformed2 = ssc.transform(
+ listOfDStreams2,
+ new Function2<List<JavaRDD<?>>, Time, JavaPairRDD<Integer, Tuple2<Integer, String>>>() {
+ public JavaPairRDD<Integer, Tuple2<Integer, String>> call(List<JavaRDD<?>> listOfRDDs, Time time) {
+ assert(listOfRDDs.size() == 3);
+ JavaRDD<Integer> rdd1 = (JavaRDD<Integer>)listOfRDDs.get(0);
+ JavaRDD<Integer> rdd2 = (JavaRDD<Integer>)listOfRDDs.get(1);
+ JavaRDD<Tuple2<Integer, String>> rdd3 = (JavaRDD<Tuple2<Integer, String>>)listOfRDDs.get(2);
+ JavaPairRDD<Integer, String> prdd3 = JavaPairRDD.fromJavaRDD(rdd3);
+ PairFunction<Integer, Integer, Integer> mapToTuple = new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i);
+ }
+ };
+ return rdd1.union(rdd2).map(mapToTuple).join(prdd3);
+ }
+ }
+ );
+ JavaTestUtils.attachTestOutputStream(transformed2);
+ List<List<Tuple2<Integer, Tuple2<Integer, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
public void testFlatMap() {
List<List<String>> inputData = Arrays.asList(
Arrays.asList("go", "giants"),
@@ -349,8 +697,8 @@ public class JavaAPISuite implements Serializable {
Arrays.asList("b", "o", "o", "d","o","d","g","e","r","s"),
Arrays.asList("a","t","h","l","e","t","i","c","s"));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream flatMapped = stream.flatMap(new FlatMapFunction<String, String>() {
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<String> flatMapped = stream.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable<String> call(String x) {
return Lists.newArrayList(x.split("(?!^)"));
@@ -396,8 +744,8 @@ public class JavaAPISuite implements Serializable {
new Tuple2<Integer, String>(9, "c"),
new Tuple2<Integer, String>(9, "s")));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaPairDStream flatMapped = stream.flatMap(new PairFlatMapFunction<String, Integer, String>() {
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<Integer,String> flatMapped = stream.flatMap(new PairFlatMapFunction<String, Integer, String>() {
@Override
public Iterable<Tuple2<Integer, String>> call(String in) throws Exception {
List<Tuple2<Integer, String>> out = Lists.newArrayList();
@@ -430,10 +778,10 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(2,2,5,5),
Arrays.asList(3,3,6,6));
- JavaDStream stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2);
- JavaDStream stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2);
+ JavaDStream<Integer> stream1 = JavaTestUtils.attachTestInputStream(ssc, inputData1, 2);
+ JavaDStream<Integer> stream2 = JavaTestUtils.attachTestInputStream(ssc, inputData2, 2);
- JavaDStream unioned = stream1.union(stream2);
+ JavaDStream<Integer> unioned = stream1.union(stream2);
JavaTestUtils.attachTestOutputStream(unioned);
List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 3, 3);
@@ -444,7 +792,7 @@ public class JavaAPISuite implements Serializable {
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
* us to account for ordering variation within individual RDD's which occurs during windowing.
*/
- public static <T extends Comparable> void assertOrderInvariantEquals(
+ public static <T extends Comparable<T>> void assertOrderInvariantEquals(
List<List<T>> expected, List<List<T>> actual) {
for (List<T> list: expected) {
Collections.sort(list);
@@ -467,11 +815,11 @@ public class JavaAPISuite implements Serializable {
Arrays.asList(new Tuple2<String, Integer>("giants", 6)),
Arrays.asList(new Tuple2<String, Integer>("yankees", 7)));
- JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Integer> pairStream = stream.map(
new PairFunction<String, String, Integer>() {
@Override
- public Tuple2 call(String in) throws Exception {
+ public Tuple2<String, Integer> call(String in) throws Exception {
return new Tuple2<String, Integer>(in, in.length());
}
});
@@ -1099,7 +1447,7 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Tuple2<List<String>, List<String>>> grouped = pairStream1.cogroup(pairStream2);
JavaTestUtils.attachTestOutputStream(grouped);
- List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<Tuple2<String, Tuple2<List<String>, List<String>>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -1142,7 +1490,38 @@ public class JavaAPISuite implements Serializable {
JavaPairDStream<String, Tuple2<String, String>> joined = pairStream1.join(pairStream2);
JavaTestUtils.attachTestOutputStream(joined);
- List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+ List<List<Tuple2<String, Tuple2<String, String>>>> result = JavaTestUtils.runStreams(ssc, 2, 2);
+
+ Assert.assertEquals(expected, result);
+ }
+
+ @Test
+ public void testLeftOuterJoin() {
+ List<List<Tuple2<String, String>>> stringStringKVStream1 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "dodgers"),
+ new Tuple2<String, String>("new york", "yankees")),
+ Arrays.asList(new Tuple2<String, String>("california", "sharks") ));
+
+ List<List<Tuple2<String, String>>> stringStringKVStream2 = Arrays.asList(
+ Arrays.asList(new Tuple2<String, String>("california", "giants") ),
+ Arrays.asList(new Tuple2<String, String>("new york", "islanders") )
+
+ );
+
+ List<List<Long>> expected = Arrays.asList(Arrays.asList(2L), Arrays.asList(1L));
+
+ JavaDStream<Tuple2<String, String>> stream1 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream1, 1);
+ JavaPairDStream<String, String> pairStream1 = JavaPairDStream.fromJavaDStream(stream1);
+
+ JavaDStream<Tuple2<String, String>> stream2 = JavaTestUtils.attachTestInputStream(
+ ssc, stringStringKVStream2, 1);
+ JavaPairDStream<String, String> pairStream2 = JavaPairDStream.fromJavaDStream(stream2);
+
+ JavaPairDStream<String, Tuple2<String, Optional<String>>> joined = pairStream1.leftOuterJoin(pairStream2);
+ JavaDStream<Long> counted = joined.count();
+ JavaTestUtils.attachTestOutputStream(counted);
+ List<List<Long>> result = JavaTestUtils.runStreams(ssc, 2, 2);
Assert.assertEquals(expected, result);
}
@@ -1163,8 +1542,8 @@ public class JavaAPISuite implements Serializable {
File tempDir = Files.createTempDir();
ssc.checkpoint(tempDir.getAbsolutePath());
- JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream letterCount = stream.map(new Function<String, Integer>() {
+ JavaDStream<String> stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaDStream<Integer> letterCount = stream.map(new Function<String, Integer>() {
@Override
public Integer call(String s) throws Exception {
return s.length();
@@ -1220,20 +1599,26 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
- JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
- JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
+ JavaPairDStream<String, String> test1 = ssc.kafkaStream("localhost:12345", "group", topics);
+ JavaPairDStream<String, String> test2 = ssc.kafkaStream("localhost:12345", "group", topics,
StorageLevel.MEMORY_AND_DISK());
HashMap<String, String> kafkaParams = Maps.newHashMap();
- kafkaParams.put("zk.connect","localhost:12345");
- kafkaParams.put("groupid","consumer-group");
- JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
+ kafkaParams.put("zookeeper.connect","localhost:12345");
+ kafkaParams.put("group.id","consumer-group");
+ JavaPairDStream<String, String> test3 = ssc.kafkaStream(
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topics,
StorageLevel.MEMORY_AND_DISK());
}
@Test
public void testSocketTextStream() {
- JavaDStream test = ssc.socketTextStream("localhost", 12345);
+ JavaDStream<String> test = ssc.socketTextStream("localhost", 12345);
}
@Test
@@ -1253,7 +1638,7 @@ public class JavaAPISuite implements Serializable {
}
}
- JavaDStream test = ssc.socketStream(
+ JavaDStream<String> test = ssc.socketStream(
"localhost",
12345,
new Converter(),
@@ -1262,39 +1647,39 @@ public class JavaAPISuite implements Serializable {
@Test
public void testTextFileStream() {
- JavaDStream test = ssc.textFileStream("/tmp/foo");
+ JavaDStream<String> test = ssc.textFileStream("/tmp/foo");
}
@Test
public void testRawSocketStream() {
- JavaDStream test = ssc.rawSocketStream("localhost", 12345);
+ JavaDStream<String> test = ssc.rawSocketStream("localhost", 12345);
}
@Test
public void testFlumeStream() {
- JavaDStream test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY());
+ JavaDStream<SparkFlumeEvent> test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY());
}
@Test
public void testFileStream() {
JavaPairDStream<String, String> foo =
- ssc.<String, String, SequenceFileInputFormat>fileStream("/tmp/foo");
+ ssc.<String, String, SequenceFileInputFormat<String,String>>fileStream("/tmp/foo");
}
@Test
public void testTwitterStream() {
String[] filters = new String[] { "good", "bad", "ugly" };
- JavaDStream test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY());
+ JavaDStream<Status> test = ssc.twitterStream(filters, StorageLevel.MEMORY_ONLY());
}
@Test
public void testActorStream() {
- JavaDStream test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY());
+ JavaDStream<String> test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY());
}
@Test
public void testZeroMQStream() {
- JavaDStream test = ssc.zeroMQStream("url", (Subscribe) null, new Function<byte[][], Iterable<String>>() {
+ JavaDStream<String> test = ssc.zeroMQStream("url", (Subscribe) null, new Function<byte[][], Iterable<String>>() {
@Override
public Iterable<String> call(byte[][] b) throws Exception {
return null;
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index 8a6604904d..42ab9590d6 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -17,7 +17,9 @@
package org.apache.spark.streaming
-import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import scala.reflect.ClassTag
+
import java.util.{List => JList}
import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStreamLike, JavaDStream, JavaStreamingContext}
import org.apache.spark.streaming._
@@ -31,15 +33,15 @@ trait JavaTestBase extends TestSuiteBase {
/**
* Create a [[org.apache.spark.streaming.TestInputStream]] and attach it to the supplied context.
* The stream will be derived from the supplied lists of Java objects.
- **/
+ */
def attachTestInputStream[T](
- ssc: JavaStreamingContext,
- data: JList[JList[T]],
- numPartitions: Int) = {
+ ssc: JavaStreamingContext,
+ data: JList[JList[T]],
+ numPartitions: Int) = {
val seqData = data.map(Seq(_:_*))
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
val dstream = new TestInputStream[T](ssc.ssc, seqData, numPartitions)
ssc.ssc.registerInputStream(dstream)
new JavaDStream[T](dstream)
@@ -50,12 +52,11 @@ trait JavaTestBase extends TestSuiteBase {
* [[org.apache.spark.streaming.TestOutputStream]].
**/
def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
- dstream: JavaDStreamLike[T, This, R]) =
+ dstream: JavaDStreamLike[T, This, R]) =
{
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
- val ostream = new TestOutputStream(dstream.dstream,
- new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]])
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
+ val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
dstream.dstream.ssc.registerOutputStream(ostream)
}
@@ -63,16 +64,39 @@ trait JavaTestBase extends TestSuiteBase {
* Process all registered streams for a numBatches batches, failing if
* numExpectedOutput RDD's are not generated. Generated RDD's are collected
* and returned, represented as a list for each batch interval.
+ *
+ * Returns a list of items for each RDD.
*/
def runStreams[V](
- ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
- implicit val cm: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
+ implicit val cm: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
val out = new ArrayList[JList[V]]()
res.map(entry => out.append(new ArrayList[V](entry)))
out
}
+
+ /**
+ * Process all registered streams for a numBatches batches, failing if
+ * numExpectedOutput RDD's are not generated. Generated RDD's are collected
+ * and returned, represented as a list for each batch interval.
+ *
+ * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
+ * representing one partition.
+ */
+ def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
+ numExpectedOutput: Int): JList[JList[JList[V]]] = {
+ implicit val cm: ClassTag[V] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
+ val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
+ val out = new ArrayList[JList[JList[V]]]()
+ res.map{entry =>
+ val lists = entry.map(new ArrayList[V](_))
+ out.append(new ArrayList[JList[V]](lists))
+ }
+ out
+ }
}
object JavaTestUtils extends JavaTestBase {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 11586f72b6..b35ca00b53 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -18,22 +18,13 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.StreamingContext._
-import scala.runtime.RichInt
-import util.ManualClock
-
-class BasicOperationsSuite extends TestSuiteBase {
- override def framework() = "BasicOperationsSuite"
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
- before {
- System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
- }
+import util.ManualClock
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
+class BasicOperationsSuite extends TestSuiteBase {
test("map") {
val input = Seq(1 to 4, 5 to 8, 9 to 12)
@@ -82,6 +73,44 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(input, operation, output, true)
}
+ test("repartition (more partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(5)
+ val ssc = setupStreams(input, operation, 2)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 5)
+ assert(second.size === 5)
+ assert(third.size === 5)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
+
+ test("repartition (fewer partitions)") {
+ val input = Seq(1 to 100, 101 to 200, 201 to 300)
+ val operation = (r: DStream[Int]) => r.repartition(2)
+ val ssc = setupStreams(input, operation, 5)
+ val output = runStreamsWithPartitions(ssc, 3, 3)
+ assert(output.size === 3)
+ val first = output(0)
+ val second = output(1)
+ val third = output(2)
+
+ assert(first.size === 2)
+ assert(second.size === 2)
+ assert(third.size === 2)
+
+ assert(first.flatten.toSet === (1 to 100).toSet)
+ assert(second.flatten.toSet === (101 to 200).toSet)
+ assert(third.flatten.toSet === (201 to 300).toSet)
+ }
+
test("groupByKey") {
testOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
@@ -143,6 +172,72 @@ class BasicOperationsSuite extends TestSuiteBase {
)
}
+ test("union") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 8, 101 to 108, 201 to 208)
+ testOperation(
+ input,
+ (s: DStream[Int]) => s.union(s.map(_ + 4)) ,
+ output
+ )
+ }
+
+ test("StreamingContext.union") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 12, 101 to 112, 201 to 212)
+ // union over 3 DStreams
+ testOperation(
+ input,
+ (s: DStream[Int]) => s.context.union(Seq(s, s.map(_ + 4), s.map(_ + 8))),
+ output
+ )
+ }
+
+ test("transform") {
+ val input = Seq(1 to 4, 5 to 8, 9 to 12)
+ testOperation(
+ input,
+ (r: DStream[Int]) => r.transform(rdd => rdd.map(_.toString)), // RDD.map in transform
+ input.map(_.map(_.toString))
+ )
+ }
+
+ test("transformWith") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, "x")), ("b", (1, "x")) ),
+ Seq( ("", (1, "x")) ),
+ Seq( ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ val t1 = s1.map(x => (x, 1))
+ val t2 = s2.map(x => (x, "x"))
+ t1.transformWith( // RDD.join in transform
+ t2,
+ (rdd1: RDD[(String, Int)], rdd2: RDD[(String, String)]) => rdd1.join(rdd2)
+ )
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("StreamingContext.transform") {
+ val input = Seq(1 to 4, 101 to 104, 201 to 204)
+ val output = Seq(1 to 12, 101 to 112, 201 to 212)
+
+ // transform over 3 DStreams by doing union of the 3 RDDs
+ val operation = (s: DStream[Int]) => {
+ s.context.transform(
+ Seq(s, s.map(_ + 4), s.map(_ + 8)), // 3 DStreams
+ (rdds: Seq[RDD[_]], time: Time) =>
+ rdds.head.context.union(rdds.map(_.asInstanceOf[RDD[Int]])) // union of RDDs
+ )
+ }
+
+ testOperation(input, operation, output)
+ }
+
test("cogroup") {
val inputData1 = Seq( Seq("a", "a", "b"), Seq("a", ""), Seq(""), Seq() )
val inputData2 = Seq( Seq("a", "a", "b"), Seq("b", ""), Seq(), Seq() )
@@ -168,7 +263,37 @@ class BasicOperationsSuite extends TestSuiteBase {
Seq( )
)
val operation = (s1: DStream[String], s2: DStream[String]) => {
- s1.map(x => (x,1)).join(s2.map(x => (x,"x")))
+ s1.map(x => (x, 1)).join(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("leftOuterJoin") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (1, Some("x"))), ("b", (1, Some("x"))) ),
+ Seq( ("", (1, Some("x"))), ("a", (1, None)) ),
+ Seq( ("", (1, None)) ),
+ Seq( )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x, 1)).leftOuterJoin(s2.map(x => (x, "x")))
+ }
+ testOperation(inputData1, inputData2, operation, outputData, true)
+ }
+
+ test("rightOuterJoin") {
+ val inputData1 = Seq( Seq("a", "b"), Seq("a", ""), Seq(""), Seq() )
+ val inputData2 = Seq( Seq("a", "b"), Seq("b", ""), Seq(), Seq("") )
+ val outputData = Seq(
+ Seq( ("a", (Some(1), "x")), ("b", (Some(1), "x")) ),
+ Seq( ("", (Some(1), "x")), ("b", (None, "x")) ),
+ Seq( ),
+ Seq( ("", (None, "x")) )
+ )
+ val operation = (s1: DStream[String], s2: DStream[String]) => {
+ s1.map(x => (x, 1)).rightOuterJoin(s2.map(x => (x, "x")))
}
testOperation(inputData1, inputData2, operation, outputData, true)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index a327de80b3..67a0841535 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -20,45 +20,45 @@ package org.apache.spark.streaming
import dstream.FileInputDStream
import org.apache.spark.streaming.StreamingContext._
import java.io.File
-import runtime.RichInt
-import org.scalatest.BeforeAndAfter
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.commons.io.FileUtils
-import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
-import util.{Clock, ManualClock}
-import scala.util.Random
+import org.scalatest.BeforeAndAfter
+
import com.google.common.io.Files
+import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions
+import org.apache.spark.streaming.dstream.FileInputDStream
+import org.apache.spark.streaming.util.ManualClock
+
+
/**
* This test suites tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
-class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+class CheckpointSuite extends TestSuiteBase {
+
+ var ssc: StreamingContext = null
+
+ override def batchDuration = Milliseconds(500)
- System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ override def actuallyWait = true // to allow checkpoints to be written
- before {
+ override def beforeFunction() {
+ super.beforeFunction()
FileUtils.deleteDirectory(new File(checkpointDir))
}
- after {
+ override def afterFunction() {
+ super.afterFunction()
if (ssc != null) ssc.stop()
FileUtils.deleteDirectory(new File(checkpointDir))
-
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
}
- var ssc: StreamingContext = null
-
- override def framework = "CheckpointSuite"
-
- override def batchDuration = Milliseconds(500)
-
- override def actuallyWait = true
-
test("basic rdd checkpoints + dstream graph checkpoint recovery") {
assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second")
@@ -74,13 +74,13 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
// Setup the streams
val input = (1 to 10).map(_ => Seq("a")).toSeq
val operation = (st: DStream[String]) => {
- val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
- Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some((values.foldLeft(0)(_ + _) + state.getOrElse(0)))
}
st.map(x => (x, 1))
- .updateStateByKey[RichInt](updateFunc)
+ .updateStateByKey(updateFunc)
.checkpoint(stateStreamCheckpointInterval)
- .map(t => (t._1, t._2.self))
+ .map(t => (t._1, t._2))
}
var ssc = setupStreams(input, operation)
var stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
@@ -180,13 +180,13 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
val input = (1 to 10).map(_ => Seq("a")).toSeq
val output = (1 to 10).map(x => Seq(("a", x))).toSeq
val operation = (st: DStream[String]) => {
- val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
- Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some((values.foldLeft(0)(_ + _) + state.getOrElse(0)))
}
st.map(x => (x, 1))
- .updateStateByKey[RichInt](updateFunc)
+ .updateStateByKey(updateFunc)
.checkpoint(batchDuration * 2)
- .map(t => (t._1, t._2.self))
+ .map(t => (t._1, t._2))
}
testCheckpointedOperation(input, operation, output, 7)
}
@@ -312,7 +312,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
- def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+ def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
@@ -356,7 +356,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
*/
- def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
+ def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logInfo("Manual clock before advancing = " + clock.time)
for (i <- 1 to numBatches.toInt) {
@@ -366,7 +366,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
logInfo("Manual clock after advancing = " + clock.time)
Thread.sleep(batchDuration.milliseconds)
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
- outputStream.output
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ outputStream.output.map(_.flatten)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 6337c5359c..da9b04de1a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer
* This testsuite tests master failures at random times while the stream is running using
* the real clock.
*/
-class FailureSuite extends FunSuite with BeforeAndAfter with Logging {
+class FailureSuite extends TestSuiteBase with Logging {
var directory = "FailureSuite"
val numBatches = 30
- val batchDuration = Milliseconds(1000)
- before {
+ override def batchDuration = Milliseconds(1000)
+
+ override def useManualClock = false
+
+ override def beforeFunction() {
+ super.beforeFunction()
FileUtils.deleteDirectory(new File(directory))
}
- after {
+ override def afterFunction() {
+ super.afterFunction()
FileUtils.deleteDirectory(new File(directory))
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 42e3e51e3f..62a9f120b4 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -23,15 +23,15 @@ import akka.actor.IOManager
import akka.actor.Props
import akka.util.ByteString
-import dstream.SparkFlumeEvent
+import org.apache.spark.streaming.dstream.{NetworkReceiver, SparkFlumeEvent}
import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket}
import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import util.ManualClock
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.receivers.Receiver
-import org.apache.spark.Logging
+import org.apache.spark.{SparkContext, Logging}
import scala.util.Random
import org.apache.commons.io.FileUtils
import org.scalatest.BeforeAndAfter
@@ -44,24 +44,12 @@ import java.nio.ByteBuffer
import collection.JavaConversions._
import java.nio.charset.Charset
import com.google.common.io.Files
+import java.util.concurrent.atomic.AtomicInteger
class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val testPort = 9999
- override def checkpointDir = "checkpoint"
-
- before {
- System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
- }
-
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
-
-
test("socket input stream") {
// Start the server
val testServer = new TestServer()
@@ -124,9 +112,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq(1, 2, 3, 4, 5)
Thread.sleep(1000)
- val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort));
+ val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort))
val client = SpecificRequestor.getClient(
- classOf[AvroSourceProtocol], transceiver);
+ classOf[AvroSourceProtocol], transceiver)
for (i <- 0 until input.size) {
val event = new AvroFlumeEvent
@@ -268,13 +256,56 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
// Test specifying decoder
- val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
- val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
+ val test3 = ssc.kafkaStream[
+ String,
+ String,
+ kafka.serializer.StringDecoder,
+ kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
+ }
+
+ test("multi-thread receiver") {
+ // set up the test receiver
+ val numThreads = 10
+ val numRecordsPerThread = 1000
+ val numTotalRecords = numThreads * numRecordsPerThread
+ val testReceiver = new MultiThreadTestReceiver(numThreads, numRecordsPerThread)
+ MultiThreadTestReceiver.haveAllThreadsFinished = false
+
+ // set up the network stream using the test receiver
+ val ssc = new StreamingContext(master, framework, batchDuration)
+ val networkStream = ssc.networkStream[Int](testReceiver)
+ val countStream = networkStream.count
+ val outputBuffer = new ArrayBuffer[Seq[Long]] with SynchronizedBuffer[Seq[Long]]
+ val outputStream = new TestOutputStream(countStream, outputBuffer)
+ def output = outputBuffer.flatMap(x => x)
+ ssc.registerOutputStream(outputStream)
+ ssc.start()
+
+ // Let the data from the receiver be received
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val startTime = System.currentTimeMillis()
+ while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
+ System.currentTimeMillis() - startTime < 5000) {
+ Thread.sleep(100)
+ clock.addToTime(batchDuration.milliseconds)
+ }
+ Thread.sleep(1000)
+ logInfo("Stopping context")
+ ssc.stop()
+
+ // Verify whether data received was as expected
+ logInfo("--------------------------------")
+ logInfo("output.size = " + outputBuffer.size)
+ logInfo("output")
+ outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+ assert(output.sum === numTotalRecords)
}
}
-/** This is server to test the network input stream */
+/** This is a server to test the network input stream */
class TestServer() extends Logging {
val queue = new ArrayBlockingQueue[String](100)
@@ -336,6 +367,7 @@ object TestServer {
}
}
+/** This is an actor for testing actor input stream */
class TestActor(port: Int) extends Actor with Receiver {
def bytesToString(byteString: ByteString) = byteString.utf8String
@@ -347,3 +379,36 @@ class TestActor(port: Int) extends Actor with Receiver {
pushBlock(bytesToString(bytes))
}
}
+
+/** This is a receiver to test multiple threads inserting data using block generator */
+class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int)
+ extends NetworkReceiver[Int] {
+ lazy val executorPool = Executors.newFixedThreadPool(numThreads)
+ lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ lazy val finishCount = new AtomicInteger(0)
+
+ protected def onStart() {
+ blockGenerator.start()
+ (1 to numThreads).map(threadId => {
+ val runnable = new Runnable {
+ def run() {
+ (1 to numRecordsPerThread).foreach(i =>
+ blockGenerator += (threadId * numRecordsPerThread + i) )
+ if (finishCount.incrementAndGet == numThreads) {
+ MultiThreadTestReceiver.haveAllThreadsFinished = true
+ }
+ logInfo("Finished thread " + threadId)
+ }
+ }
+ executorPool.submit(runnable)
+ })
+ }
+
+ protected def onStop() {
+ executorPool.shutdown()
+ }
+}
+
+object MultiThreadTestReceiver {
+ var haveAllThreadsFinished = false
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
new file mode 100644
index 0000000000..fa64142096
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import org.apache.spark.streaming.scheduler._
+import scala.collection.mutable.ArrayBuffer
+import org.scalatest.matchers.ShouldMatchers
+
+class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
+
+ val input = (1 to 4).map(Seq(_)).toSeq
+ val operation = (d: DStream[Int]) => d.map(x => x)
+
+ // To make sure that the processing start and end times in collected
+ // information are different for successive batches
+ override def batchDuration = Milliseconds(100)
+ override def actuallyWait = true
+
+ test("basic BatchInfo generation") {
+ val ssc = setupStreams(input, operation)
+ val collector = new BatchInfoCollector
+ ssc.addStreamingListener(collector)
+ runStreams(ssc, input.size, input.size)
+ val batchInfos = collector.batchInfos
+ batchInfos should have size 4
+
+ batchInfos.foreach(info => {
+ info.schedulingDelay should not be None
+ info.processingDelay should not be None
+ info.totalDelay should not be None
+ info.schedulingDelay.get should be >= 0L
+ info.processingDelay.get should be >= 0L
+ info.totalDelay.get should be >= 0L
+ })
+
+ isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true)
+ isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true)
+ isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true)
+ }
+
+ /** Check if a sequence of numbers is in increasing order */
+ def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
+ for(i <- 1 until seq.size) {
+ if (seq(i - 1) > seq(i)) return false
+ }
+ true
+ }
+
+ /** Listener that collects information on processed batches */
+ class BatchInfoCollector extends StreamingListener {
+ val batchInfos = new ArrayBuffer[BatchInfo]
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ batchInfos += batchCompleted.batchInfo
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 37dd9c4cc6..e969e91d13 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -20,8 +20,9 @@ package org.apache.spark.streaming
import org.apache.spark.streaming.dstream.{InputDStream, ForEachDStream}
import org.apache.spark.streaming.util.ManualClock
-import collection.mutable.ArrayBuffer
-import collection.mutable.SynchronizedBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.SynchronizedBuffer
+import scala.reflect.ClassTag
import java.io.{ObjectInputStream, IOException}
@@ -35,7 +36,7 @@ import org.apache.spark.rdd.RDD
* replayable, reliable message queue like Kafka. It requires a sequence as input, and
* returns the i_th element at the i_th batch unde manual clock.
*/
-class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
+class TestInputStream[T: ClassTag](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
extends InputDStream[T](ssc_) {
def start() {}
@@ -60,8 +61,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
/**
* This is a output stream just for the testsuites. All the output is collected into a
* ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDD's, each containing a sequence of items
*/
-class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+class TestOutputStream[T: ClassTag](parent: DStream[T],
+ val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]())
extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
@@ -76,13 +80,37 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
}
/**
+ * This is a output stream just for the testsuites. All the output is collected into a
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ *
+ * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each
+ * containing a sequence of items.
+ */
+class TestOutputStreamWithPartitions[T: ClassTag](parent: DStream[T],
+ val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]())
+ extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.glom().collect().map(_.toSeq)
+ output += collected
+ }) {
+
+ // This is to clear the output buffer every it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+
+ def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten))
+}
+
+/**
* This is the base trait for Spark Streaming testsuites. This provides basic functionality
* to run user-defined set of input on user-defined stream operations, and verify the output.
*/
trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
- def framework = "TestSuiteBase"
+ def framework = this.getClass.getSimpleName
// Master for Spark context
def master = "local[2]"
@@ -99,16 +127,47 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Maximum time to wait before the test times out
def maxWaitTimeMillis = 10000
+ // Whether to use manual clock or not
+ def useManualClock = true
+
// Whether to actually wait in real time before changing manual clock
def actuallyWait = false
+ // Default before function for any streaming test suite. Override this
+ // if you want to add your stuff to "before" (i.e., don't call before { } )
+ def beforeFunction() {
+ if (useManualClock) {
+ System.setProperty(
+ "spark.streaming.clock",
+ "org.apache.spark.streaming.util.ManualClock"
+ )
+ } else {
+ System.clearProperty("spark.streaming.clock")
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ }
+
+ // Default after function for any streaming test suite. Override this
+ // if you want to add your stuff to "after" (i.e., don't call after { } )
+ def afterFunction() {
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ }
+
+ before(beforeFunction)
+ after(afterFunction)
+
/**
* Set up required DStreams to test the DStream operation using the two sequences
* of input collections.
*/
- def setupStreams[U: ClassManifest, V: ClassManifest](
+ def setupStreams[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V]
+ operation: DStream[U] => DStream[V],
+ numPartitions: Int = numInputPartitions
): StreamingContext = {
// Create StreamingContext
@@ -118,9 +177,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
}
// Setup the stream computation
- val inputStream = new TestInputStream(ssc, input, numInputPartitions)
+ val inputStream = new TestInputStream(ssc, input, numPartitions)
val operatedStream = operation(inputStream)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
ssc.registerInputStream(inputStream)
ssc.registerOutputStream(outputStream)
ssc
@@ -130,7 +190,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Set up required DStreams to test the binary operation using the sequence
* of input collections.
*/
- def setupStreams[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ def setupStreams[U: ClassTag, V: ClassTag, W: ClassTag](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
operation: (DStream[U], DStream[V]) => DStream[W]
@@ -146,7 +206,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions)
val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions)
val operatedStream = operation(inputStream1, inputStream2)
- val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]])
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]])
ssc.registerInputStream(inputStream1)
ssc.registerInputStream(inputStream2)
ssc.registerOutputStream(outputStream)
@@ -157,18 +218,37 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
* returns the collected output. It will wait until `numExpectedOutput` number of
* output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of items for each RDD.
*/
- def runStreams[V: ClassManifest](
+ def runStreams[V: ClassTag](
ssc: StreamingContext,
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
+ // Flatten each RDD into a single Seq
+ runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+ }
+
+ /**
+ * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+ * returns the collected output. It will wait until `numExpectedOutput` number of
+ * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
+ *
+ * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
+ * representing one partition.
+ */
+ def runStreamsWithPartitions[V: ClassTag](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[Seq[V]]] = {
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
// Get the output buffer
- val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
val output = outputStream.output
try {
@@ -202,7 +282,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
Thread.sleep(500) // Give some time for the forgetting old RDDs to complete
} catch {
- case e: Exception => e.printStackTrace(); throw e;
+ case e: Exception => {e.printStackTrace(); throw e}
} finally {
ssc.stop()
}
@@ -214,7 +294,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* is same as the expected output values, by comparing the output
* collections either as lists (order matters) or sets (order does not matter)
*/
- def verifyOutput[V: ClassManifest](
+ def verifyOutput[V: ClassTag](
output: Seq[Seq[V]],
expectedOutput: Seq[Seq[V]],
useSet: Boolean
@@ -244,7 +324,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Test unary DStream operation with a list of inputs, with number of
* batches to run same as the number of expected output values
*/
- def testOperation[U: ClassManifest, V: ClassManifest](
+ def testOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
@@ -262,7 +342,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* @param useSet Compare the output values with the expected output values
* as sets (order matters) or as lists (order does not matter)
*/
- def testOperation[U: ClassManifest, V: ClassManifest](
+ def testOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
@@ -279,7 +359,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* Test binary DStream operation with two lists of inputs, with number of
* batches to run same as the number of expected output values
*/
- def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ def testOperation[U: ClassTag, V: ClassTag, W: ClassTag](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
operation: (DStream[U], DStream[V]) => DStream[W],
@@ -299,7 +379,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
* @param useSet Compare the output values with the expected output values
* as sets (order matters) or as lists (order does not matter)
*/
- def testOperation[U: ClassManifest, V: ClassManifest, W: ClassManifest](
+ def testOperation[U: ClassTag, V: ClassTag, W: ClassTag](
input1: Seq[Seq[U]],
input2: Seq[Seq[V]],
operation: (DStream[U], DStream[V]) => DStream[W],
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index f50e05c0d8..6b4aaefcdf 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -22,19 +22,9 @@ import collection.mutable.ArrayBuffer
class WindowOperationsSuite extends TestSuiteBase {
- System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer
- override def framework = "WindowOperationsSuite"
-
- override def maxWaitTimeMillis = 20000
-
- override def batchDuration = Seconds(1)
-
- after {
- // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
- System.clearProperty("spark.driver.port")
- System.clearProperty("spark.hostPort")
- }
+ override def batchDuration = Seconds(1) // making sure its visible in this class
val largerSlideInput = Seq(
Seq(("a", 1)),
diff --git a/tools/pom.xml b/tools/pom.xml
index f1c489beea..28f5ef14b1 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -25,7 +25,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-tools_2.9.3</artifactId>
+ <artifactId>spark-tools_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Tools</name>
<url>http://spark.incubator.apache.org/</url>
@@ -33,24 +33,24 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-streaming_2.9.3</artifactId>
+ <artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
index f824c472ae..f670f65bf5 100644
--- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala
@@ -199,7 +199,7 @@ object JavaAPICompletenessChecker {
private def toJavaMethod(method: SparkMethod): SparkMethod = {
val params = method.parameters
- .filterNot(_.name == "scala.reflect.ClassManifest")
+ .filterNot(_.name == "scala.reflect.ClassTag")
.map(toJavaType(_, isReturnType = false))
SparkMethod(method.name, toJavaType(method.returnType, isReturnType = true), params)
}
@@ -212,7 +212,7 @@ object JavaAPICompletenessChecker {
// internal Spark components.
val excludedNames = Seq(
"org.apache.spark.rdd.RDD.origin",
- "org.apache.spark.rdd.RDD.elementClassManifest",
+ "org.apache.spark.rdd.RDD.elementClassTag",
"org.apache.spark.rdd.RDD.checkpointData",
"org.apache.spark.rdd.RDD.partitioner",
"org.apache.spark.rdd.RDD.partitions",
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 3bc619df07..bc64a190fd 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -25,7 +25,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-yarn_2.9.3</artifactId>
+ <artifactId>spark-yarn_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project YARN Support</name>
<url>http://spark.incubator.apache.org/</url>
@@ -33,7 +33,7 @@
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
@@ -61,11 +61,21 @@
<groupId>org.apache.avro</groupId>
<artifactId>avro-ipc</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
@@ -106,6 +116,46 @@
</execution>
</executions>
</plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-antrun-plugin</artifactId>
+ <executions>
+ <execution>
+ <phase>test</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <exportAntProperties>true</exportAntProperties>
+ <tasks>
+ <property name="spark.classpath" refid="maven.test.classpath" />
+ <property environment="env" />
+ <fail message="Please set the SCALA_HOME (or SCALA_LIBRARY_PATH if scala is on the path) environment variables and retry.">
+ <condition>
+ <not>
+ <or>
+ <isset property="env.SCALA_HOME" />
+ <isset property="env.SCALA_LIBRARY_PATH" />
+ </or>
+ </not>
+ </condition>
+ </fail>
+ </tasks>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <configuration>
+ <environmentVariables>
+ <SPARK_HOME>${basedir}/..</SPARK_HOME>
+ <SPARK_TESTING>1</SPARK_TESTING>
+ <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
+ </environmentVariables>
+ </configuration>
+ </plugin>
</plugins>
</build>
</project>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c1a87d3373..240ed8b32a 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -17,14 +17,17 @@
package org.apache.spark.deploy.yarn
-import java.io.IOException;
+import java.io.IOException
import java.net.Socket
-import java.security.PrivilegedExceptionAction
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+
+import scala.collection.JavaConversions._
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
@@ -32,47 +35,49 @@ import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.util.Utils
-import org.apache.hadoop.security.UserGroupInformation
-import scala.collection.JavaConversions._
+
class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
def this(args: ApplicationMasterArguments) = this(args, new Configuration())
private var rpc: YarnRPC = YarnRPC.create(conf)
- private var resourceManager: AMRMProtocol = null
- private var appAttemptId: ApplicationAttemptId = null
- private var userThread: Thread = null
+ private var resourceManager: AMRMProtocol = _
+ private var appAttemptId: ApplicationAttemptId = _
+ private var userThread: Thread = _
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
private val fs = FileSystem.get(yarnConf)
- private var yarnAllocator: YarnAllocationHandler = null
- private var isFinished:Boolean = false
- private var uiAddress: String = ""
+ private var yarnAllocator: YarnAllocationHandler = _
+ private var isFinished: Boolean = false
+ private var uiAddress: String = _
private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES,
YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES)
private var isLastAMRetry: Boolean = true
-
+ // default to numWorkers * 2, with minimum of 3
+ private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures",
+ math.max(args.numWorkers * 2, 3).toString()).toInt
def run() {
- // setup the directories so things go to yarn approved directories rather
- // then user specified and /tmp
+ // Setup the directories so things go to yarn approved directories rather
+ // then user specified and /tmp.
System.setProperty("spark.local.dir", getLocalDirs())
- // use priority 30 as its higher then HDFS. Its same priority as MapReduce is using
+ // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using.
ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
appAttemptId = getApplicationAttemptId()
- isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts;
+ isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
resourceManager = registerWithResourceManager()
// Workaround until hadoop moves to something which has
// https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line)
- // ignore result
+ // ignore result.
// This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
- // Hence args.workerCores = numCore disabled above. Any better option ?
+ // Hence args.workerCores = numCore disabled above. Any better option?
// Compute number of threads for akka
//val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
@@ -98,7 +103,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
waitForSparkContextInitialized()
- // do this after spark master is up and SparkContext is created so that we can register UI Url
+ // Do this after spark master is up and SparkContext is created so that we can register UI Url
val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
// Allocate all containers
@@ -117,12 +122,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
- .getOrElse(""))
+ .getOrElse(""))
if (localDirs.isEmpty()) {
throw new Exception("Yarn Local dirs can't be empty")
}
- return localDirs
+ localDirs
}
private def getApplicationAttemptId(): ApplicationAttemptId = {
@@ -131,7 +136,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
val containerId = ConverterUtils.toContainerId(containerIdString)
val appAttemptId = containerId.getApplicationAttemptId()
logInfo("ApplicationAttemptId: " + appAttemptId)
- return appAttemptId
+ appAttemptId
}
private def registerWithResourceManager(): AMRMProtocol = {
@@ -139,7 +144,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
YarnConfiguration.RM_SCHEDULER_ADDRESS,
YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
logInfo("Connecting to ResourceManager at " + rmAddress)
- return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
@@ -147,12 +152,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
.asInstanceOf[RegisterApplicationMasterRequest]
appMasterRequest.setApplicationAttemptId(appAttemptId)
- // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Setting this to master host,port - so that the ApplicationReport at client has some
+ // sensible info.
// Users can then monitor stderr/stdout on that node if required.
appMasterRequest.setHost(Utils.localHostName())
appMasterRequest.setRpcPort(0)
appMasterRequest.setTrackingUrl(uiAddress)
- return resourceManager.registerApplicationMaster(appMasterRequest)
+ resourceManager.registerApplicationMaster(appMasterRequest)
}
private def waitForSparkMaster() {
@@ -166,28 +172,32 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
try {
val socket = new Socket(driverHost, driverPort.toInt)
socket.close()
- logInfo("Driver now available: " + driverHost + ":" + driverPort)
+ logInfo("Driver now available: %s:%s".format(driverHost, driverPort))
driverUp = true
} catch {
- case e: Exception =>
- logWarning("Failed to connect to driver at " + driverHost + ":" + driverPort + ", retrying")
- Thread.sleep(100)
- tries = tries + 1
+ case e: Exception => {
+ logWarning("Failed to connect to driver at %s:%s, retrying ...".
+ format(driverHost, driverPort))
+ Thread.sleep(100)
+ tries = tries + 1
+ }
}
}
}
private def startUserClass(): Thread = {
logInfo("Starting the user JAR in a separate Thread")
- val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
- .getMethod("main", classOf[Array[String]])
+ val mainMethod = Class.forName(
+ args.userClass,
+ false /* initialize */,
+ Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
val t = new Thread {
override def run() {
var successed = false
try {
// Copy
- var mainArgs: Array[String] = new Array[String](args.userArgs.size())
- args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
+ var mainArgs: Array[String] = new Array[String](args.userArgs.size)
+ args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
mainMethod.invoke(null, mainArgs)
// some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR
// userThread will stop here unless it has uncaught exception thrown out
@@ -195,7 +205,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
successed = true
} finally {
logDebug("finishing main")
- isLastAMRetry = true;
+ isLastAMRetry = true
if (successed) {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
} else {
@@ -205,7 +215,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
}
}
t.start()
- return t
+ t
}
// this need to happen before allocateWorkers
@@ -227,12 +237,20 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
if (null != sparkContext) {
uiAddress = sparkContext.ui.appUIAddress
- this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args,
- sparkContext.preferredNodeLocationData)
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(
+ yarnConf,
+ resourceManager,
+ appAttemptId,
+ args,
+ sparkContext.preferredNodeLocationData)
} else {
- logWarning("Unable to retrieve sparkContext inspite of waiting for " + count * waitTime +
- ", numTries = " + numTries)
- this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args)
+ logWarning("Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d".
+ format(count * waitTime, numTries))
+ this.yarnAllocator = YarnAllocationHandler.newAllocator(
+ yarnConf,
+ resourceManager,
+ appAttemptId,
+ args)
}
}
} finally {
@@ -248,44 +266,57 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
// Wait until all containers have finished
// TODO: This is a bit ugly. Can we make it nicer?
// TODO: Handle container failure
- while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
- // If user thread exists, then quit !
- userThread.isAlive) {
- this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
- ApplicationMaster.incrementAllocatorLoop(1)
- Thread.sleep(100)
+ // Exists the loop if the user thread exits.
+ while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) {
+ if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of worker failures reached")
+ }
+ yarnAllocator.allocateContainers(
+ math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ ApplicationMaster.incrementAllocatorLoop(1)
+ Thread.sleep(100)
}
} finally {
- // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT :
- // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
+ // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
+ // so that the loop in ApplicationMaster#sparkContextInitialized() breaks.
ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
}
logInfo("All workers have launched.")
- // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ // Launch a progress reporter thread, else the app will get killed after expiration
+ // (def: 10mins) timeout.
+ // TODO(harvey): Verify the timeout
if (userThread.isAlive) {
- // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
-
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
- // must be <= timeoutInterval/ 2.
- // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
- // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
- val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong
+
+ // must be <= timeoutInterval / 2.
+ val interval = math.min(timeoutInterval / 2, schedulerInterval)
+
launchReporterThread(interval)
}
}
- // TODO: We might want to extend this to allocate more containers in case they die !
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
val t = new Thread {
override def run() {
while (userThread.isAlive) {
+ if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of worker failures reached")
+ }
val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
if (missingWorkerCount > 0) {
- logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ logInfo("Allocating %d containers to make up for (potentially) lost containers".
+ format(missingWorkerCount))
yarnAllocator.allocateContainers(missingWorkerCount)
}
else sendProgress()
@@ -293,16 +324,16 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
}
}
}
- // setting to daemon status, though this is usually not a good idea.
+ // Setting to daemon status, though this is usually not a good idea.
t.setDaemon(true)
t.start()
logInfo("Started progress reporter thread - sleep time : " + sleepTime)
- return t
+ t
}
private def sendProgress() {
logDebug("Sending progress")
- // simulated with an allocate request with no nodes requested ...
+ // Simulated with an allocate request with no nodes requested ...
yarnAllocator.allocateContainers(0)
}
@@ -321,8 +352,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
}
*/
- def finishApplicationMaster(status: FinalApplicationStatus) {
-
+ def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") {
synchronized {
if (isFinished) {
return
@@ -335,21 +365,21 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
.asInstanceOf[FinishApplicationMasterRequest]
finishReq.setAppAttemptId(appAttemptId)
finishReq.setFinishApplicationStatus(status)
- // set tracking url to empty since we don't have a history server
+ finishReq.setDiagnostics(diagnostics)
+ // Set tracking url to empty since we don't have a history server.
finishReq.setTrackingUrl("")
resourceManager.finishApplicationMaster(finishReq)
-
}
/**
- * clean up the staging directory.
+ * Clean up the staging directory.
*/
private def cleanupStagingDir() {
var stagingDirPath: Path = null
try {
val preserveFiles = System.getProperty("spark.yarn.preserve.staging.files", "false").toBoolean
if (!preserveFiles) {
- stagingDirPath = new Path(System.getenv("SPARK_YARN_JAR_PATH")).getParent()
+ stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR"))
if (stagingDirPath == null) {
logError("Staging directory is null")
return
@@ -358,13 +388,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
fs.delete(stagingDirPath, true)
}
} catch {
- case e: IOException =>
- logError("Failed to cleanup staging dir " + stagingDirPath, e)
+ case ioe: IOException =>
+ logError("Failed to cleanup staging dir " + stagingDirPath, ioe)
}
}
- // The shutdown hook that runs when a signal is received AND during normal
- // close of the JVM.
+ // The shutdown hook that runs when a signal is received AND during normal close of the JVM.
class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable {
def run() {
@@ -374,15 +403,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e
if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
}
}
-
}
object ApplicationMaster {
- // number of times to wait for the allocator loop to complete.
- // each loop iteration waits for 100ms, so maximum of 3 seconds.
+ // Number of times to wait for the allocator loop to complete.
+ // Each loop iteration waits for 100ms, so maximum of 3 seconds.
// This is to ensure that we have reasonable number of containers before we start
- // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be optimal as more
- // containers are available. Might need to handle this better.
+ // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
+ // optimal as more containers are available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
def incrementAllocatorLoop(by: Int) {
val count = yarnAllocatorLoop.getAndAdd(by)
@@ -400,7 +428,8 @@ object ApplicationMaster {
applicationMasters.add(master)
}
- val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
+ val sparkContextRef: AtomicReference[SparkContext] =
+ new AtomicReference[SparkContext](null /* initialValue */)
val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
def sparkContextInitialized(sc: SparkContext): Boolean = {
@@ -410,19 +439,21 @@ object ApplicationMaster {
sparkContextRef.notifyAll()
}
- // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do System.exit
- // Should not really have to do this, but it helps yarn to evict resources earlier.
- // not to mention, prevent Client declaring failure even though we exit'ed properly.
- // Note that this will unfortunately not properly clean up the staging files because it gets called to
- // late and the filesystem is already shutdown.
+ // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do
+ // System.exit.
+ // Should not really have to do this, but it helps YARN to evict resources earlier.
+ // Not to mention, prevent the Client from declaring failure even though we exited properly.
+ // Note that this will unfortunately not properly clean up the staging files because it gets
+ // called too late, after the filesystem is already shutdown.
if (modified) {
Runtime.getRuntime().addShutdownHook(new Thread with Logging {
- // This is not just to log, but also to ensure that log system is initialized for this instance when we actually are 'run'
+ // This is not only logs, but also ensures that log system is initialized for this instance
+ // when we are actually 'run'-ing.
logInfo("Adding shutdown hook for context " + sc)
override def run() {
logInfo("Invoking sc stop from shutdown hook")
sc.stop()
- // best case ...
+ // Best case ...
for (master <- applicationMasters) {
master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
}
@@ -430,7 +461,7 @@ object ApplicationMaster {
} )
}
- // Wait for initialization to complete and atleast 'some' nodes can get allocated
+ // Wait for initialization to complete and atleast 'some' nodes can get allocated.
yarnAllocatorLoop.synchronized {
while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
yarnAllocatorLoop.wait(1000L)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8afb3e39cb..79dd038065 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -17,43 +17,54 @@
package org.apache.spark.deploy.yarn
-import java.net.{InetSocketAddress, URI}
+import java.net.{InetAddress, UnknownHostException, URI}
import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil}
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.mapred.Master
import org.apache.hadoop.net.NetUtils
-import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
-import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.YarnClientImpl
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
-import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
+import org.apache.hadoop.yarn.util.{Apps, Records}
+
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.spark.deploy.SparkHadoopUtil
+
class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
-
+
def this(args: ClientArguments) = this(new Configuration(), args)
-
+
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
val credentials = UserGroupInformation.getCurrentUser().getCredentials()
- private var distFiles = None: Option[String]
- private var distFilesTimeStamps = None: Option[String]
- private var distFilesFileSizes = None: Option[String]
- private var distArchives = None: Option[String]
- private var distArchivesTimeStamps = None: Option[String]
- private var distArchivesFileSizes = None: Option[String]
-
- def run() {
+ private val SPARK_STAGING: String = ".sparkStaging"
+ private val distCacheMgr = new ClientDistributedCacheManager()
+
+ // Staging directory is private! -> rwx--------
+ val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
+
+ // App files are world-wide readable and owner writable -> rw-r--r--
+ val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
+
+ // for client user who want to monitor app status by itself.
+ def runApp() = {
+ validateArgs()
+
init(yarnConf)
start()
logClusterResourceDetails()
@@ -63,8 +74,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
- val localResources = prepareLocalResources(appId, ".sparkStaging")
- val env = setupLaunchEnv(localResources)
+ val appStagingDir = getAppStagingDir(appId)
+ val localResources = prepareLocalResources(appStagingDir)
+ val env = setupLaunchEnv(localResources, appStagingDir)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
appContext.setQueue(args.amQueue)
@@ -72,30 +84,59 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
submitApp(appContext)
-
+ appId
+ }
+
+ def run() {
+ val appId = runApp()
monitorApplication(appId)
System.exit(0)
}
-
+
+ def validateArgs() = {
+ Map(
+ (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!",
+ (args.userJar == null) -> "Error: You must specify a user jar!",
+ (args.userClass == null) -> "Error: You must specify a user class!",
+ (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!",
+ (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be " +
+ "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD),
+ (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size " +
+ "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD)
+ ).foreach { case(cond, errStr) =>
+ if (cond) {
+ logError(errStr)
+ args.printUsageAndExit(1)
+ }
+ }
+ }
+
+ def getAppStagingDir(appId: ApplicationId): String = {
+ SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ }
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
- logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
+ logInfo("Got Cluster metric info from ASM, numNodeManagers = " +
+ clusterMetrics.getNumNodeManagers)
val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
- logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
- ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
- ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
+ logInfo("""Queue info ... queueName = %s, queueCurrentCapacity = %s, queueMaxCapacity = %s,
+ queueApplicationCount = %s, queueChildQueueCount = %s""".format(
+ queueInfo.getQueueName,
+ queueInfo.getCurrentCapacity,
+ queueInfo.getMaximumCapacity,
+ queueInfo.getApplications.size,
+ queueInfo.getChildQueues.size))
}
-
def verifyClusterResources(app: GetNewApplicationResponse) = {
val maxMem = app.getMaximumResourceCapability().getMemory()
logInfo("Max mem capabililty of a single resource in this cluster " + maxMem)
-
- // if we have requested more then the clusters max for a single resource then exit.
+
+ // If we have requested more then the clusters max for a single resource then exit.
if (args.workerMemory > maxMem) {
- logError("the worker size is to large to run on this cluster " + args.workerMemory);
+ logError("the worker size is to large to run on this cluster " + args.workerMemory)
System.exit(1)
}
val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD
@@ -104,10 +145,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
System.exit(1)
}
- // We could add checks to make sure the entire cluster has enough resources but that involves getting
- // all the node reports and computing ourselves
+ // We could add checks to make sure the entire cluster has enough resources but that involves
+ // getting all the node reports and computing ourselves
}
-
+
def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
logInfo("Setting up application submission context for ASM")
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
@@ -116,198 +157,169 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
return appContext
}
- /**
- * Copy the local file into HDFS and configure to be distributed with the
- * job via the distributed cache.
- * If a fragment is specified the file will be referenced as that fragment.
- */
- private def copyLocalFile(
- dstDir: Path,
- resourceType: LocalResourceType,
- originalPath: Path,
- replication: Short,
- localResources: HashMap[String,LocalResource],
- fragment: String,
- appMasterOnly: Boolean = false): Unit = {
- val fs = FileSystem.get(conf)
- val newPath = new Path(dstDir, originalPath.getName())
- logInfo("Uploading " + originalPath + " to " + newPath)
- fs.copyFromLocalFile(false, true, originalPath, newPath)
- fs.setReplication(newPath, replication);
- val destStatus = fs.getFileStatus(newPath)
-
- val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- amJarRsrc.setType(resourceType)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
- amJarRsrc.setTimestamp(destStatus.getModificationTime())
- amJarRsrc.setSize(destStatus.getLen())
- var pathURI: URI = new URI(newPath.toString() + "#" + originalPath.getName());
- if ((fragment == null) || (fragment.isEmpty())){
- localResources(originalPath.getName()) = amJarRsrc
- } else {
- localResources(fragment) = amJarRsrc
- pathURI = new URI(newPath.toString() + "#" + fragment);
+ /** See if two file systems are the same or not. */
+ private def compareFs(srcFs: FileSystem, destFs: FileSystem): Boolean = {
+ val srcUri = srcFs.getUri()
+ val dstUri = destFs.getUri()
+ if (srcUri.getScheme() == null) {
+ return false
+ }
+ if (!srcUri.getScheme().equals(dstUri.getScheme())) {
+ return false
}
- val distPath = pathURI.toString()
- if (appMasterOnly == true) return
- if (resourceType == LocalResourceType.FILE) {
- distFiles match {
- case Some(path) =>
- distFilesFileSizes = Some(distFilesFileSizes.get + "," +
- destStatus.getLen().toString())
- distFilesTimeStamps = Some(distFilesTimeStamps.get + "," +
- destStatus.getModificationTime().toString())
- distFiles = Some(path + "," + distPath)
- case _ =>
- distFilesFileSizes = Some(destStatus.getLen().toString())
- distFilesTimeStamps = Some(destStatus.getModificationTime().toString())
- distFiles = Some(distPath)
+ var srcHost = srcUri.getHost()
+ var dstHost = dstUri.getHost()
+ if ((srcHost != null) && (dstHost != null)) {
+ try {
+ srcHost = InetAddress.getByName(srcHost).getCanonicalHostName()
+ dstHost = InetAddress.getByName(dstHost).getCanonicalHostName()
+ } catch {
+ case e: UnknownHostException =>
+ return false
}
- } else {
- distArchives match {
- case Some(path) =>
- distArchivesTimeStamps = Some(distArchivesTimeStamps.get + "," +
- destStatus.getModificationTime().toString())
- distArchivesFileSizes = Some(distArchivesFileSizes.get + "," +
- destStatus.getLen().toString())
- distArchives = Some(path + "," + distPath)
- case _ =>
- distArchivesTimeStamps = Some(destStatus.getModificationTime().toString())
- distArchivesFileSizes = Some(destStatus.getLen().toString())
- distArchives = Some(distPath)
+ if (!srcHost.equals(dstHost)) {
+ return false
}
+ } else if (srcHost == null && dstHost != null) {
+ return false
+ } else if (srcHost != null && dstHost == null) {
+ return false
+ }
+ //check for ports
+ if (srcUri.getPort() != dstUri.getPort()) {
+ return false
}
+ return true
}
- def prepareLocalResources(appId: ApplicationId, sparkStagingDir: String): HashMap[String, LocalResource] = {
+ /** Copy the file into HDFS if needed. */
+ private def copyRemoteFile(
+ dstDir: Path,
+ originalPath: Path,
+ replication: Short,
+ setPerms: Boolean = false): Path = {
+ val fs = FileSystem.get(conf)
+ val remoteFs = originalPath.getFileSystem(conf)
+ var newPath = originalPath
+ if (! compareFs(remoteFs, fs)) {
+ newPath = new Path(dstDir, originalPath.getName())
+ logInfo("Uploading " + originalPath + " to " + newPath)
+ FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf)
+ fs.setReplication(newPath, replication)
+ if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION))
+ }
+ // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific
+ // version shows the specific version in the distributed cache configuration
+ val qualPath = fs.makeQualified(newPath)
+ val fc = FileContext.getFileContext(qualPath.toUri(), conf)
+ val destPath = fc.resolvePath(qualPath)
+ destPath
+ }
+
+ def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
- // Upload Spark and the application JAR to the remote file system
- // Add them as local resources to the AM
+ // Upload Spark and the application JAR to the remote file system if necessary. Add them as
+ // local resources to the AM.
val fs = FileSystem.get(conf)
- val delegTokenRenewer = Master.getMasterPrincipal(conf);
+ val delegTokenRenewer = Master.getMasterPrincipal(conf)
if (UserGroupInformation.isSecurityEnabled()) {
if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
logError("Can't get Master Kerberos principal for use as renewer")
System.exit(1)
}
}
-
- val pathSuffix = sparkStagingDir + "/" + appId.toString() + "/"
- val dst = new Path(fs.getHomeDirectory(), pathSuffix)
+ val dst = new Path(fs.getHomeDirectory(), appStagingDir)
val replication = System.getProperty("spark.yarn.submit.file.replication", "3").toShort
if (UserGroupInformation.isSecurityEnabled()) {
val dstFs = dst.getFileSystem(conf)
- dstFs.addDelegationTokens(delegTokenRenewer, credentials);
+ dstFs.addDelegationTokens(delegTokenRenewer, credentials)
}
val localResources = HashMap[String, LocalResource]()
+ FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION))
- Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+
+ Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar,
+ Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
- val src = new Path(localPath)
- val newPath = new Path(dst, destName)
- logInfo("Uploading " + src + " to " + newPath)
- fs.copyFromLocalFile(false, true, src, newPath)
- fs.setReplication(newPath, replication);
- val destStatus = fs.getFileStatus(newPath)
-
- val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- amJarRsrc.setType(LocalResourceType.FILE)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
- amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(newPath))
- amJarRsrc.setTimestamp(destStatus.getModificationTime())
- amJarRsrc.setSize(destStatus.getLen())
- localResources(destName) = amJarRsrc
+ var localURI = new URI(localPath)
+ // if not specified assume these are in the local filesystem to keep behavior like Hadoop
+ if (localURI.getScheme() == null) {
+ localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString)
+ }
+ val setPermissions = if (destName.equals(Client.APP_JAR)) true else false
+ val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ destName, statCache)
}
}
// handle any add jars
if ((args.addJars != null) && (!args.addJars.isEmpty())){
args.addJars.split(',').foreach { case file: String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
- tmpURI.getFragment(), true)
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache, true)
}
}
// handle any distributed cache files
if ((args.files != null) && (!args.files.isEmpty())){
args.files.split(',').foreach { case file: String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.FILE, tmp, replication, localResources,
- tmpURI.getFragment())
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE,
+ linkname, statCache)
}
}
// handle any distributed cache archives
if ((args.archives != null) && (!args.archives.isEmpty())) {
args.archives.split(',').foreach { case file:String =>
- val tmpURI = new URI(file)
- val tmp = new Path(tmpURI)
- copyLocalFile(dst, LocalResourceType.ARCHIVE, tmp, replication,
- localResources, tmpURI.getFragment())
+ val localURI = new URI(file.trim())
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyRemoteFile(dst, localPath, replication)
+ distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE,
+ linkname, statCache)
}
}
- UserGroupInformation.getCurrentUser().addCredentials(credentials);
+ UserGroupInformation.getCurrentUser().addCredentials(credentials)
return localResources
}
-
- def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
+
+ def setupLaunchEnv(
+ localResources: HashMap[String, LocalResource],
+ stagingDir: String): HashMap[String, String] = {
logInfo("Setting up the launch environment")
- val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
+ val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null)
val env = new HashMap[String, String]()
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
- Apps.addToEnvironment(env, Environment.CLASSPATH.name,
- Environment.PWD.$() + Path.SEPARATOR + "*")
-
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
env("SPARK_YARN_MODE") = "true"
- env("SPARK_YARN_JAR_PATH") =
- localResources("spark.jar").getResource().getScheme.toString() + "://" +
- localResources("spark.jar").getResource().getFile().toString()
- env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
- env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
-
- env("SPARK_YARN_USERJAR_PATH") =
- localResources("app.jar").getResource().getScheme.toString() + "://" +
- localResources("app.jar").getResource().getFile().toString()
- env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
- env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
-
- if (log4jConfLocalRes != null) {
- env("SPARK_YARN_LOG4J_PATH") =
- log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
- env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
- env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
- }
+ env("SPARK_YARN_STAGING_DIR") = stagingDir
- // set the environment variables to be passed on to the Workers
- if (distFiles != None) {
- env("SPARK_YARN_CACHE_FILES") = distFiles.get
- env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") = distFilesTimeStamps.get
- env("SPARK_YARN_CACHE_FILES_FILE_SIZES") = distFilesFileSizes.get
- }
- if (distArchives != None) {
- env("SPARK_YARN_CACHE_ARCHIVES") = distArchives.get
- env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") = distArchivesTimeStamps.get
- env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") = distArchivesFileSizes.get
- }
+ // Set the environment variables to be passed on to the Workers.
+ distCacheMgr.setDistFilesEnv(env)
+ distCacheMgr.setDistArchivesEnv(env)
- // allow users to specify some environment variables
+ // Allow users to specify some environment variables.
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
- // Add each SPARK-* key to the environment
+ // Add each SPARK-* key to the environment.
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
- return env
+ env
}
def userArgsToString(clientArgs: ClientArguments): String = {
@@ -317,13 +329,13 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
for (arg <- args){
retval.append(prefix).append(" '").append(arg).append("' ")
}
-
retval.toString
}
- def createContainerLaunchContext(newApp: GetNewApplicationResponse,
- localResources: HashMap[String, LocalResource],
- env: HashMap[String, String]): ContainerLaunchContext = {
+ def createContainerLaunchContext(
+ newApp: GetNewApplicationResponse,
+ localResources: HashMap[String, LocalResource],
+ env: HashMap[String, String]): ContainerLaunchContext = {
logInfo("Setting up container launch context")
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
amContainer.setLocalResources(localResources)
@@ -331,8 +343,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
+ // TODO(harvey): This can probably be a val.
var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
- (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
+ ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) -
+ YarnAllocationHandler.MEMORY_OVERHEAD)
// Extra options for the JVM
var JAVA_OPTS = ""
@@ -343,14 +357,18 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
JAVA_OPTS += " -Djava.io.tmpdir=" +
new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
-
- // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
- // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
- // node, spark gc effects all other containers performance (which can also be other spark containers)
- // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
- // limited to subset of cores on a node.
- if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
- // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tenant machines
+ // Commenting it out for now - so that people can refer to the properties if required. Remove
+ // it once cpuset version is pushed out. The context is, default gc for server class machines
+ // end up using all cores to do gc - hence if there are multiple containers in same node,
+ // spark gc effects all other containers performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+ // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+ // of cores on a node.
+ val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") &&
+ java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))
+ if (useConcurrentAndIncrementalGC) {
+ // In our expts, using (default) throughput collector has severe perf ramnifications in
+ // multi-tenant machines
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
@@ -363,7 +381,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
}
// Command for the ApplicationMaster
- var javaCommand = "java";
+ var javaCommand = "java"
val javaHome = System.getenv("JAVA_HOME")
if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
@@ -372,7 +390,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val commands = List[String](javaCommand +
" -server " +
JAVA_OPTS +
- " org.apache.spark.deploy.yarn.ApplicationMaster" +
+ " " + args.amClass +
" --class " + args.userClass +
" --jar " + args.userJar +
userArgsToString(args) +
@@ -383,28 +401,28 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Command for the ApplicationMaster: " + commands(0))
amContainer.setCommands(commands)
-
+
val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
- // Memory for the ApplicationMaster
+ // Memory for the ApplicationMaster.
capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
amContainer.setResource(capability)
- // Setup security tokens
+ // Setup security tokens.
val dob = new DataOutputBuffer()
credentials.writeTokenStorageToStream(dob)
amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData()))
- return amContainer
+ amContainer
}
-
+
def submitApp(appContext: ApplicationSubmissionContext) = {
- // Submit the application to the applications manager
+ // Submit the application to the applications manager.
logInfo("Submitting application to ASM")
super.submitApplication(appContext)
}
-
+
def monitorApplication(appId: ApplicationId): Boolean = {
- while(true) {
+ while (true) {
Thread.sleep(1000)
val report = super.getApplicationReport(appId)
@@ -422,26 +440,31 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
"\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
"\t appUser: " + report.getUser()
)
-
+
val state = report.getYarnApplicationState()
val dsStatus = report.getFinalApplicationStatus()
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
- return true
+ return true
}
}
- return true
+ true
}
}
object Client {
+ val SPARK_JAR: String = "spark.jar"
+ val APP_JAR: String = "app.jar"
+ val LOG4J_PROP: String = "log4j.properties"
+
def main(argStrings: Array[String]) {
// Set an env variable indicating we are running in YARN mode.
// Note that anything with SPARK prefix gets propagated to all (remote) processes
System.setProperty("SPARK_YARN_MODE", "true")
val args = new ClientArguments(argStrings)
+
new Client(args).run
}
@@ -451,4 +474,30 @@ object Client {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
+
+ def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
+ // If log4j present, ensure ours overrides all others
+ if (addLog4j) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + LOG4J_PROP)
+ }
+ // Normally the users app.jar is last in case conflicts with spark jars
+ val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
+ .toBoolean
+ if (userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + APP_JAR)
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + SPARK_JAR)
+ Client.populateHadoopClasspath(conf, env)
+
+ if (!userClasspathFirst) {
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + APP_JAR)
+ }
+ Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
+ Path.SEPARATOR + "*")
+ }
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 852dbd7dab..b3a7886d93 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) {
var numWorkers = 2
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
+ var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
var appName: String = "Spark"
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
@@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) {
userArgsBuffer += value
args = tail
- case ("--master-memory") :: MemoryParam(value) :: tail =>
- amMemory = value
+ case ("--master-class") :: value :: tail =>
+ amClass = value
args = tail
- case ("--num-workers") :: IntParam(value) :: tail =>
- numWorkers = value
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
args = tail
case ("--worker-memory") :: MemoryParam(value) :: tail =>
workerMemory = value
args = tail
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
@@ -84,6 +89,7 @@ class ClientArguments(val args: Array[String]) {
case ("--name") :: value :: tail =>
appName = value
+ args = tail
case ("--addJars") :: value :: tail =>
addJars = value
@@ -119,19 +125,20 @@ class ClientArguments(val args: Array[String]) {
System.err.println(
"Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
- " --jar JAR_PATH Path to your application's JAR file (required)\n" +
- " --class CLASS_NAME Name of your application's main class (required)\n" +
- " --args ARGS Arguments to be passed to your application's main class.\n" +
- " Mutliple invocations are possible, each will be passed in order.\n" +
- " --num-workers NUM Number of workers to start (Default: 2)\n" +
- " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
- " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
- " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
- " --name NAME The name of your application (Default: Spark)\n" +
- " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
- " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
- " --files files Comma separated list of files to be distributed with the job.\n" +
- " --archives archives Comma separated list of archives to be distributed with the job."
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
)
System.exit(exitCode)
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
new file mode 100644
index 0000000000..5f159b073f
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import org.apache.spark.Logging
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.Map
+
+
+/** Client side methods to setup the Hadoop distributed cache */
+class ClientDistributedCacheManager() extends Logging {
+ private val distCacheFiles: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+ private val distCacheArchives: Map[String, Tuple3[String, String, String]] =
+ LinkedHashMap[String, Tuple3[String, String, String]]()
+
+
+ /**
+ * Add a resource to the list of distributed cache resources. This list can
+ * be sent to the ApplicationMaster and possibly the workers so that it can
+ * be downloaded into the Hadoop distributed cache for use by this application.
+ * Adds the LocalResource to the localResources HashMap passed in and saves
+ * the stats of the resources to they can be sent to the workers and verified.
+ *
+ * @param fs FileSystem
+ * @param conf Configuration
+ * @param destPath path to the resource
+ * @param localResources localResource hashMap to insert the resource into
+ * @param resourceType LocalResourceType
+ * @param link link presented in the distributed cache to the destination
+ * @param statCache cache to store the file/directory stats
+ * @param appMasterOnly Whether to only add the resource to the app master
+ */
+ def addResource(
+ fs: FileSystem,
+ conf: Configuration,
+ destPath: Path,
+ localResources: HashMap[String, LocalResource],
+ resourceType: LocalResourceType,
+ link: String,
+ statCache: Map[URI, FileStatus],
+ appMasterOnly: Boolean = false) = {
+ val destStatus = fs.getFileStatus(destPath)
+ val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
+ amJarRsrc.setType(resourceType)
+ val visibility = getVisibility(conf, destPath.toUri(), statCache)
+ amJarRsrc.setVisibility(visibility)
+ amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(destPath))
+ amJarRsrc.setTimestamp(destStatus.getModificationTime())
+ amJarRsrc.setSize(destStatus.getLen())
+ if (link == null || link.isEmpty()) throw new Exception("You must specify a valid link name")
+ localResources(link) = amJarRsrc
+
+ if (appMasterOnly == false) {
+ val uri = destPath.toUri()
+ val pathURI = new URI(uri.getScheme(), uri.getAuthority(), uri.getPath(), null, link)
+ if (resourceType == LocalResourceType.FILE) {
+ distCacheFiles(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ } else {
+ distCacheArchives(pathURI.toString()) = (destStatus.getLen().toString(),
+ destStatus.getModificationTime().toString(), visibility.name())
+ }
+ }
+ }
+
+ /**
+ * Adds the necessary cache file env variables to the env passed in
+ * @param env
+ */
+ def setDistFilesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheFiles.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_FILES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_FILES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Adds the necessary cache archive env variables to the env passed in
+ * @param env
+ */
+ def setDistArchivesEnv(env: Map[String, String]) = {
+ val (keys, tupleValues) = distCacheArchives.unzip
+ val (sizes, timeStamps, visibilities) = tupleValues.unzip3
+
+ if (keys.size > 0) {
+ env("SPARK_YARN_CACHE_ARCHIVES") = keys.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") =
+ timeStamps.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") =
+ sizes.reduceLeft[String] { (acc,n) => acc + "," + n }
+ env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") =
+ visibilities.reduceLeft[String] { (acc,n) => acc + "," + n }
+ }
+ }
+
+ /**
+ * Returns the local resource visibility depending on the cache file permissions
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return LocalResourceVisibility
+ */
+ def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ if (isPublic(conf, uri, statCache)) {
+ return LocalResourceVisibility.PUBLIC
+ }
+ return LocalResourceVisibility.PRIVATE
+ }
+
+ /**
+ * Returns a boolean to denote whether a cache file is visible to all(public)
+ * or not
+ * @param conf
+ * @param uri
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def isPublic(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]): Boolean = {
+ val fs = FileSystem.get(uri, conf)
+ val current = new Path(uri.getPath())
+ //the leaf level file should be readable by others
+ if (!checkPermissionOfOther(fs, current, FsAction.READ, statCache)) {
+ return false
+ }
+ return ancestorsHaveExecutePermissions(fs, current.getParent(), statCache)
+ }
+
+ /**
+ * Returns true if all ancestors of the specified path have the 'execute'
+ * permission set for all users (i.e. that other users can traverse
+ * the directory heirarchy to the given path)
+ * @param fs
+ * @param path
+ * @param statCache
+ * @return true if all ancestors have the 'execute' permission set for all users
+ */
+ def ancestorsHaveExecutePermissions(fs: FileSystem, path: Path,
+ statCache: Map[URI, FileStatus]): Boolean = {
+ var current = path
+ while (current != null) {
+ //the subdirs in the path should have execute permissions for others
+ if (!checkPermissionOfOther(fs, current, FsAction.EXECUTE, statCache)) {
+ return false
+ }
+ current = current.getParent()
+ }
+ return true
+ }
+
+ /**
+ * Checks for a given path whether the Other permissions on it
+ * imply the permission in the passed FsAction
+ * @param fs
+ * @param path
+ * @param action
+ * @param statCache
+ * @return true if the path in the uri is visible to all, false otherwise
+ */
+ def checkPermissionOfOther(fs: FileSystem, path: Path,
+ action: FsAction, statCache: Map[URI, FileStatus]): Boolean = {
+ val status = getFileStatus(fs, path.toUri(), statCache)
+ val perms = status.getPermission()
+ val otherAction = perms.getOtherAction()
+ if (otherAction.implies(action)) {
+ return true
+ }
+ return false
+ }
+
+ /**
+ * Checks to see if the given uri exists in the cache, if it does it
+ * returns the existing FileStatus, otherwise it stats the uri, stores
+ * it in the cache, and returns the FileStatus.
+ * @param fs
+ * @param uri
+ * @param statCache
+ * @return FileStatus
+ */
+ def getFileStatus(fs: FileSystem, uri: URI, statCache: Map[URI, FileStatus]): FileStatus = {
+ val stat = statCache.get(uri) match {
+ case Some(existstat) => existstat
+ case None =>
+ val newStat = fs.getFileStatus(new Path(uri))
+ statCache.put(uri, newStat)
+ newStat
+ }
+ return stat
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000000..69038844bb
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote._
+import akka.actor.Terminated
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private val rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var reporterThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+ private var driverClosed:Boolean = false
+
+ val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1
+ var actor: ActorRef = null
+
+ // This actor just working as a monitor to watch on Driver Actor.
+ class MonitorActor(driverUrl: String) extends Actor {
+
+ var driver: ActorSelection = null
+
+ override def preStart() {
+ logInfo("Listen to driver: " + driverUrl)
+ driver = context.actorSelection(driverUrl)
+ driver ! "hello"
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case x: DisassociatedEvent =>
+ logInfo(s"Driver terminated or disconnected! Shutting down. $x")
+ driverClosed = true
+ }
+ }
+
+ def run() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ reporterThread = launchReporterThread(interval)
+
+ // Wait for the reporter thread to Finish.
+ reporterThread.join()
+
+ finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ actorSystem.shutdown()
+
+ logInfo("Exited")
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+ while(!driverUp) {
+ try {
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ System.setProperty("spark.driver.host", driverHost)
+ System.setProperty("spark.driver.port", driverPort.toString)
+
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
+ driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+ }
+
+
+ private def allocateWorkers() {
+
+ // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()
+
+ yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData)
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers) {
+ yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ Thread.sleep(100)
+ }
+
+ logInfo("All workers have launched.")
+
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (!driverClosed) {
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ def finishApplicationMaster(status: FinalApplicationStatus) {
+
+ logInfo("finish ApplicationMaster with " + status)
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ finishReq.setFinishApplicationStatus(status)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+
+object WorkerLauncher {
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new WorkerLauncher(args).run()
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
index 8dac9e02ac..6a90cc51cf 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala
@@ -21,53 +21,59 @@ import java.net.URI
import java.nio.ByteBuffer
import java.security.PrivilegedExceptionAction
+import scala.collection.JavaConversions._
+import scala.collection.mutable.HashMap
+
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils}
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
-
-import scala.collection.JavaConversions._
-import scala.collection.mutable.HashMap
import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
- slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
- extends Runnable with Logging {
-
+
+class WorkerRunnable(
+ container: Container,
+ conf: Configuration,
+ masterAddress: String,
+ slaveId: String,
+ hostname: String,
+ workerMemory: Int,
+ workerCores: Int)
+ extends Runnable with Logging {
+
var rpc: YarnRPC = YarnRPC.create(conf)
var cm: ContainerManager = null
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
-
+
def run = {
logInfo("Starting Worker Container")
cm = connectToCM
startContainer
}
-
+
def startContainer = {
logInfo("Setting up ContainerLaunchContext")
-
+
val ctx = Records.newRecord(classOf[ContainerLaunchContext])
.asInstanceOf[ContainerLaunchContext]
-
+
ctx.setContainerId(container.getId())
ctx.setResource(container.getResource())
val localResources = prepareLocalResources
ctx.setLocalResources(localResources)
-
+
val env = prepareEnvironment
ctx.setEnvironment(env)
-
+
// Extra options for the JVM
var JAVA_OPTS = ""
// Set the JVM memory
@@ -80,17 +86,21 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
JAVA_OPTS += " -Djava.io.tmpdir=" +
new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " "
-
- // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
- // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same
- // node, spark gc effects all other containers performance (which can also be other spark containers)
- // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in multi-tenant environments. Not sure how default java gc behaves if it is
- // limited to subset of cores on a node.
+ // Commenting it out for now - so that people can refer to the properties if required. Remove
+ // it once cpuset version is pushed out.
+ // The context is, default gc for server class machines end up using all cores to do gc - hence
+ // if there are multiple containers in same node, spark gc effects all other containers
+ // performance (which can also be other spark containers)
+ // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in
+ // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset
+ // of cores on a node.
/*
else {
// If no java_opts specified, default to using -XX:+CMSIncrementalMode
- // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it.
- // In our expts, using (default) throughput collector has severe perf ramnifications in multi-tennent machines
+ // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont
+ // want to mess with it.
+ // In our expts, using (default) throughput collector has severe perf ramnifications in
+ // multi-tennent machines
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
@@ -108,7 +118,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
credentials.writeTokenStorageToStream(dob)
ctx.setContainerTokens(ByteBuffer.wrap(dob.getData()))
- var javaCommand = "java";
+ var javaCommand = "java"
val javaHome = System.getenv("JAVA_HOME")
if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) {
javaCommand = Environment.JAVA_HOME.$() + "/bin/java"
@@ -117,11 +127,13 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
val commands = List[String](javaCommand +
" -server " +
// Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
- // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
- // TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
+ // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in
+ // an inconsistent state.
+ // TODO: If the OOM is not recoverable by rescheduling it on different node, then do
+ // 'something' to fail job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
- " org.apache.spark.executor.StandaloneExecutorBackend " +
+ " org.apache.spark.executor.CoarseGrainedExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
@@ -130,7 +142,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Setting up worker with commands: " + commands)
ctx.setCommands(commands)
-
+
// Send the start request to the ContainerManager
val startReq = Records.newRecord(classOf[StartContainerRequest])
.asInstanceOf[StartContainerRequest]
@@ -138,64 +150,35 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
cm.startContainer(startReq)
}
- private def setupDistributedCache(file: String,
+ private def setupDistributedCache(
+ file: String,
rtype: LocalResourceType,
localResources: HashMap[String, LocalResource],
timestamp: String,
- size: String) = {
+ size: String,
+ vis: String) = {
val uri = new URI(file)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(rtype)
- amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
+ amJarRsrc.setVisibility(LocalResourceVisibility.valueOf(vis))
amJarRsrc.setResource(ConverterUtils.getYarnUrlFromURI(uri))
amJarRsrc.setTimestamp(timestamp.toLong)
amJarRsrc.setSize(size.toLong)
localResources(uri.getFragment()) = amJarRsrc
}
-
-
+
def prepareLocalResources: HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
val localResources = HashMap[String, LocalResource]()
-
- // Spark JAR
- val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- sparkJarResource.setType(LocalResourceType.FILE)
- sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
- sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
- sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
- sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
- localResources("spark.jar") = sparkJarResource
- // User JAR
- val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- userJarResource.setType(LocalResourceType.FILE)
- userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
- userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
- userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
- userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
- localResources("app.jar") = userJarResource
-
- // Log4j conf - if available
- if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
- val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
- log4jConfResource.setType(LocalResourceType.FILE)
- log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
- log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
- new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
- log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
- log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
- localResources("log4j.properties") = log4jConfResource
- }
if (System.getenv("SPARK_YARN_CACHE_FILES") != null) {
val timeStamps = System.getenv("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
val fileSizes = System.getenv("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
val distFiles = System.getenv("SPARK_YARN_CACHE_FILES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_FILES_VISIBILITIES").split(',')
for( i <- 0 to distFiles.length - 1) {
setupDistributedCache(distFiles(i), LocalResourceType.FILE, localResources, timeStamps(i),
- fileSizes(i))
+ fileSizes(i), visibilities(i))
}
}
@@ -203,40 +186,38 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
val timeStamps = System.getenv("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS").split(',')
val fileSizes = System.getenv("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES").split(',')
val distArchives = System.getenv("SPARK_YARN_CACHE_ARCHIVES").split(',')
+ val visibilities = System.getenv("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES").split(',')
for( i <- 0 to distArchives.length - 1) {
setupDistributedCache(distArchives(i), LocalResourceType.ARCHIVE, localResources,
- timeStamps(i), fileSizes(i))
+ timeStamps(i), fileSizes(i), visibilities(i))
}
}
-
+
logInfo("Prepared Local resources " + localResources)
return localResources
}
-
+
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
- Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
- Apps.addToEnvironment(env, Environment.CLASSPATH.name,
- Environment.PWD.$() + Path.SEPARATOR + "*")
- Client.populateHadoopClasspath(yarnConf, env)
+ Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
- // allow users to specify some environment variables
+ // Allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
return env
}
-
+
def connectToCM: ContainerManager = {
val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
logInfo("Connecting to ContainerManager at " + cmHostPortStr)
- // use doAs and remoteUser here so we can add the container token and not
- // pollute the current users credentials with all of the individual container tokens
- val user = UserGroupInformation.createRemoteUser(container.getId().toString());
- val containerToken = container.getContainerToken();
+ // Use doAs and remoteUser here so we can add the container token and not pollute the current
+ // users credentials with all of the individual container tokens
+ val user = UserGroupInformation.createRemoteUser(container.getId().toString())
+ val containerToken = container.getContainerToken()
if (containerToken != null) {
user.addToken(ProtoUtils.convertFromProtoFormat(containerToken, cmAddress))
}
@@ -247,8 +228,8 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
return rpc.getProxy(classOf[ContainerManager],
cmAddress, conf).asInstanceOf[ContainerManager]
}
- });
- return proxy;
+ })
+ proxy
}
-
+
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 25da9aa917..9ab2073529 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -17,87 +17,112 @@
package org.apache.spark.deploy.yarn
+import java.lang.{Boolean => JBoolean}
+import java.util.{Collections, Set => JSet}
+import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
import org.apache.spark.Logging
-import org.apache.spark.util.Utils
import org.apache.spark.scheduler.SplitInfo
-import scala.collection
-import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.util.Utils
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.yarn.api.AMRMProtocol
+import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId}
+import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus}
+import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
-import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
-import java.util.concurrent.atomic.AtomicInteger
-import org.apache.hadoop.yarn.api.AMRMProtocol
-import collection.JavaConversions._
-import collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import org.apache.hadoop.conf.Configuration
-import java.util.{Collections, Set => JSet}
-import java.lang.{Boolean => JBoolean}
-object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
+
+object AllocationType extends Enumeration {
type AllocationType = Value
val HOST, RACK, ANY = Value
}
-// too many params ? refactor it 'somehow' ?
-// needs to be mt-safe
-// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
-// more proactive and decoupled.
+// TODO:
+// Too many params.
+// Needs to be mt-safe
+// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should
+// make it more proactive and decoupled.
+
// Note that right now, we assume all node asks as uniform in terms of capabilities and priority
-// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
-// on how we are requesting for containers.
-private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
- val appAttemptId: ApplicationAttemptId,
- val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
- val preferredHostToCount: Map[String, Int],
- val preferredRackToCount: Map[String, Int])
+// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for
+// more info on how we are requesting for containers.
+private[yarn] class YarnAllocationHandler(
+ val conf: Configuration,
+ val resourceManager: AMRMProtocol,
+ val appAttemptId: ApplicationAttemptId,
+ val maxWorkers: Int,
+ val workerMemory: Int,
+ val workerCores: Int,
+ val preferredHostToCount: Map[String, Int],
+ val preferredRackToCount: Map[String, Int])
extends Logging {
-
-
// These three are locked on allocatedHostToContainersMap. Complementary data structures
// allocatedHostToContainersMap : containers which are running : host, Set<containerid>
- // allocatedContainerToHostMap: container to host mapping
- private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
+ // allocatedContainerToHostMap: container to host mapping.
+ private val allocatedHostToContainersMap =
+ new HashMap[String, collection.mutable.Set[ContainerId]]()
+
private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
- // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
- // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
+
+ // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an
+ // allocated node)
+ // As with the two data structures above, tightly coupled with them, and to be locked on
+ // allocatedHostToContainersMap
private val allocatedRackCount = new HashMap[String, Int]()
- // containers which have been released.
+ // Containers which have been released.
private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
- // containers to be released in next request to RM
+ // Containers to be released in next request to RM
private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
private val numWorkersRunning = new AtomicInteger()
// Used to generate a unique id per worker
private val workerIdCounter = new AtomicInteger()
private val lastResponseId = new AtomicInteger()
+ private val numWorkersFailed = new AtomicInteger()
def getNumWorkersRunning: Int = numWorkersRunning.intValue
+ def getNumWorkersFailed: Int = numWorkersFailed.intValue
def isResourceConstraintSatisfied(container: Container): Boolean = {
container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
}
def allocateContainers(workersToRequest: Int) {
- // We need to send the request only once from what I understand ... but for now, not modifying this much.
+ // We need to send the request only once from what I understand ... but for now, not modifying
+ // this much.
// Keep polling the Resource Manager for containers
val amResp = allocateWorkerResources(workersToRequest).getAMResponse
val _allocatedContainers = amResp.getAllocatedContainers()
- if (_allocatedContainers.size > 0) {
-
- logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
- numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
- ", pendingReleaseContainers : " + pendingReleaseContainers)
- logDebug("Cluster Resources: " + amResp.getAvailableResources)
+ if (_allocatedContainers.size > 0) {
+ logDebug("""
+ Allocated containers: %d
+ Current worker count: %d
+ Containers released: %s
+ Containers to be released: %s
+ Cluster resources: %s
+ """.format(
+ _allocatedContainers.size,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers,
+ amResp.getAvailableResources))
val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
- // ignore if not satisfying constraints {
+ // Ignore if not satisfying constraints {
for (container <- _allocatedContainers) {
if (isResourceConstraintSatisfied(container)) {
// allocatedContainers += container
@@ -111,8 +136,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
else releasedContainerList.add(container.getId())
}
- // Find the appropriate containers to use
- // Slightly non trivial groupBy I guess ...
+ // Find the appropriate containers to use. Slightly non trivial groupBy ...
val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
@@ -132,21 +156,22 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
remainingContainers = null
}
else if (requiredHostCount > 0) {
- // container list has more containers than we need for data locality.
- // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
- // and rest as remainingContainer
- val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
+ // Container list has more containers than we need for data locality.
+ // Split into two : data local container count of (remainingContainers.size -
+ // requiredHostCount) and rest as remainingContainer
+ val (dataLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredHostCount)
dataLocalContainers.put(candidateHost, dataLocal)
// remainingContainers = remaining
// yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
- // add remaining to release list. If we have insufficient containers, next allocation cycle
- // will reallocate (but wont treat it as data local)
+ // add remaining to release list. If we have insufficient containers, next allocation
+ // cycle will reallocate (but wont treat it as data local)
for (container <- remaining) releasedContainerList.add(container.getId())
remainingContainers = null
}
- // now rack local
+ // Now rack local
if (remainingContainers != null){
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
@@ -159,15 +184,17 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
if (requiredRackCount >= remainingContainers.size){
// Add all to dataLocalContainers
dataLocalContainers.put(rack, remainingContainers)
- // all consumed
+ // All consumed
remainingContainers = null
}
else if (requiredRackCount > 0) {
// container list has more containers than we need for data locality.
- // Split into two : data local container count of (remainingContainers.size - requiredRackCount)
- // and rest as remainingContainer
- val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
- val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
+ // Split into two : data local container count of (remainingContainers.size -
+ // requiredRackCount) and rest as remainingContainer
+ val (rackLocal, remaining) = remainingContainers.splitAt(
+ remainingContainers.size - requiredRackCount)
+ val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack,
+ new ArrayBuffer[Container]())
existingRackLocal ++= rackLocal
remainingContainers = remaining
@@ -183,8 +210,8 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
// Now that we have split the containers into various groups, go through them in order :
// first host local, then rack local and then off rack (everything else).
- // Note that the list we create below tries to ensure that not all containers end up within a host
- // if there are sufficiently large number of hosts/containers.
+ // Note that the list we create below tries to ensure that not all containers end up within a
+ // host if there are sufficiently large number of hosts/containers.
val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
@@ -197,33 +224,39 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val workerHostname = container.getNodeId.getHost
val containerId = container.getId
- assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
+ assert(
+ container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
if (numWorkersRunningNow > maxWorkers) {
- logInfo("Ignoring container " + containerId + " at host " + workerHostname +
- " .. we already have required number of containers")
+ logInfo("""Ignoring container %s at host %s, since we already have the required number of
+ containers for it.""".format(containerId, workerHostname))
releasedContainerList.add(containerId)
// reset counter back to old value.
numWorkersRunning.decrementAndGet()
}
else {
- // deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
+ // Deallocate + allocate can result in reusing id's wrongly - so use a different counter
+ // (workerIdCounter)
val workerId = workerIdCounter.incrementAndGet().toString
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
- // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
+ // Just to be safe, simply remove it from pendingReleaseContainers.
+ // Should not be there, but ..
pendingReleaseContainers.remove(containerId)
val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
allocatedHostToContainersMap.synchronized {
- val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
+ val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname,
+ new HashSet[ContainerId]())
containerSet += containerId
allocatedContainerToHostMap.put(containerId, workerHostname)
- if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ if (rack != null) {
+ allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
+ }
}
new Thread(
@@ -232,17 +265,23 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
).start()
}
}
- logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
- _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
- ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("""
+ Finished processing %d containers.
+ Current number of workers running: %d,
+ releasedContainerList: %s,
+ pendingReleaseContainers: %s
+ """.format(
+ allocatedContainers.size,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers))
}
val completedContainers = amResp.getCompletedContainersStatuses()
if (completedContainers.size > 0){
- logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
- ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
-
+ logDebug("Completed %d containers, to-be-released: %s".format(
+ completedContainers.size, releasedContainerList))
for (completedContainer <- completedContainers){
val containerId = completedContainer.getContainerId
@@ -251,10 +290,19 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
pendingReleaseContainers.remove(containerId)
}
else {
- // simply decrement count - next iteration of ReporterThread will take care of allocating !
+ // Simply decrement count - next iteration of ReporterThread will take care of allocating.
numWorkersRunning.decrementAndGet()
- logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
- " httpaddress: " + completedContainer.getDiagnostics)
+ logInfo("Completed container %s (state: %s, exit status: %s)".format(
+ containerId,
+ completedContainer.getState,
+ completedContainer.getExitStatus()))
+ // Hadoop 2.2.X added a ContainerExitStatus we should switch to use
+ // there are some exit status' we shouldn't necessarily count against us, but for
+ // now I think its ok as none of the containers are expected to exit
+ if (completedContainer.getExitStatus() != 0) {
+ logInfo("Container marked as failed: " + containerId)
+ numWorkersFailed.incrementAndGet()
+ }
}
allocatedHostToContainersMap.synchronized {
@@ -271,7 +319,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
allocatedContainerToHostMap -= containerId
- // doing this within locked context, sigh ... move to outside ?
+ // Doing this within locked context, sigh ... move to outside ?
val rack = YarnAllocationHandler.lookupRack(conf, host)
if (rack != null) {
val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
@@ -281,9 +329,16 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
}
}
}
- logDebug("After completed " + completedContainers.size + " containers, current count " +
- numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
- ", pendingReleaseContainers : " + pendingReleaseContainers)
+ logDebug("""
+ Finished processing %d completed containers.
+ Current number of workers running: %d,
+ releasedContainerList: %s,
+ pendingReleaseContainers: %s
+ """.format(
+ completedContainers.size,
+ numWorkersRunning.get(),
+ releasedContainerList,
+ pendingReleaseContainers))
}
}
@@ -337,7 +392,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
// default.
if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
- logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
+ logDebug("numWorkers: " + numWorkers + ", host preferences: " + preferredHostToCount.isEmpty)
resourceRequests = List(
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
}
@@ -350,17 +405,24 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
if (requiredCount > 0) {
- hostContainerRequests +=
- createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
+ hostContainerRequests += createResourceRequest(
+ AllocationType.HOST,
+ candidateHost,
+ requiredCount,
+ YarnAllocationHandler.PRIORITY)
}
}
- val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
+ val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(
+ hostContainerRequests.toList)
- val anyContainerRequests: ResourceRequest =
- createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
+ val anyContainerRequests: ResourceRequest = createResourceRequest(
+ AllocationType.ANY,
+ resource = null,
+ numWorkers,
+ YarnAllocationHandler.PRIORITY)
- val containerRequests: ArrayBuffer[ResourceRequest] =
- new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
+ val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest](
+ hostContainerRequests.size + rackContainerRequests.size + 1)
containerRequests ++= hostContainerRequests
containerRequests ++= rackContainerRequests
@@ -378,55 +440,60 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val releasedContainerList = createReleasedContainerList()
req.addAllReleases(releasedContainerList)
-
-
if (numWorkers > 0) {
- logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
+ logInfo("Allocating %d worker containers with %d of memory each.".format(numWorkers,
+ workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
}
else {
logDebug("Empty allocation req .. release : " + releasedContainerList)
}
- for (req <- resourceRequests) {
- logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
- ", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
+ for (request <- resourceRequests) {
+ logInfo("ResourceRequest (host : %s, num containers: %d, priority = %s , capability : %s)".
+ format(
+ request.getHostName,
+ request.getNumContainers,
+ request.getPriority,
+ request.getCapability))
}
resourceManager.allocate(req)
}
- private def createResourceRequest(requestType: AllocationType.AllocationType,
- resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
+ private def createResourceRequest(
+ requestType: AllocationType.AllocationType,
+ resource:String,
+ numWorkers: Int,
+ priority: Int): ResourceRequest = {
// If hostname specified, we need atleast two requests - node local and rack local.
// There must be a third request - which is ANY : that will be specially handled.
requestType match {
case AllocationType.HOST => {
- assert (YarnAllocationHandler.ANY_HOST != resource)
-
+ assert(YarnAllocationHandler.ANY_HOST != resource)
val hostname = resource
val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
- // add to host->rack mapping
+ // Add to host->rack mapping
YarnAllocationHandler.populateRackInfo(conf, hostname)
nodeLocal
}
-
case AllocationType.RACK => {
val rack = resource
createResourceRequestImpl(rack, numWorkers, priority)
}
-
- case AllocationType.ANY => {
- createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
- }
-
- case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
+ case AllocationType.ANY => createResourceRequestImpl(
+ YarnAllocationHandler.ANY_HOST, numWorkers, priority)
+ case _ => throw new IllegalArgumentException(
+ "Unexpected/unsupported request type: " + requestType)
}
}
- private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
+ private def createResourceRequestImpl(
+ hostname:String,
+ numWorkers: Int,
+ priority: Int): ResourceRequest = {
val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
val memCapability = Records.newRecord(classOf[Resource])
@@ -447,11 +514,11 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
val retval = new ArrayBuffer[ContainerId](1)
- // iterator on COW list ...
+ // Iterator on COW list ...
for (container <- releasedContainerList.iterator()){
retval += container
}
- // remove from the original list.
+ // Remove from the original list.
if (! retval.isEmpty) {
releasedContainerList.removeAll(retval)
for (v <- retval) pendingReleaseContainers.put(v, true)
@@ -466,14 +533,14 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
object YarnAllocationHandler {
val ANY_HOST = "*"
- // all requests are issued with same priority : we do not (yet) have any distinction between
+ // All requests are issued with same priority : we do not (yet) have any distinction between
// request types (like map/reduce in hadoop for example)
val PRIORITY = 1
// Additional memory overhead - in mb
val MEMORY_OVERHEAD = 384
- // host to rack map - saved from allocation requests
+ // Host to rack map - saved from allocation requests
// We are expecting this not to change.
// Note that it is possible for this to change : and RM will indicate that to us via update
// response to allocate. But we are punting on handling that for now.
@@ -481,38 +548,69 @@ object YarnAllocationHandler {
private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
- def newAllocator(conf: Configuration,
- resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
- args: ApplicationMasterArguments): YarnAllocationHandler = {
-
- new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
- args.workerMemory, args.workerCores, Map[String, Int](), Map[String, Int]())
+ def newAllocator(
+ conf: Configuration,
+ resourceManager: AMRMProtocol,
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments): YarnAllocationHandler = {
+
+ new YarnAllocationHandler(
+ conf,
+ resourceManager,
+ appAttemptId,
+ args.numWorkers,
+ args.workerMemory,
+ args.workerCores,
+ Map[String, Int](),
+ Map[String, Int]())
}
- def newAllocator(conf: Configuration,
- resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
- args: ApplicationMasterArguments,
- map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+ def newAllocator(
+ conf: Configuration,
+ resourceManager: AMRMProtocol,
+ appAttemptId: ApplicationAttemptId,
+ args: ApplicationMasterArguments,
+ map: collection.Map[String,
+ collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
-
- new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
- args.workerMemory, args.workerCores, hostToCount, rackToCount)
+ new YarnAllocationHandler(
+ conf,
+ resourceManager,
+ appAttemptId,
+ args.numWorkers,
+ args.workerMemory,
+ args.workerCores,
+ hostToCount,
+ rackToCount)
}
- def newAllocator(conf: Configuration,
- resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
- maxWorkers: Int, workerMemory: Int, workerCores: Int,
- map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
+ def newAllocator(
+ conf: Configuration,
+ resourceManager: AMRMProtocol,
+ appAttemptId: ApplicationAttemptId,
+ maxWorkers: Int,
+ workerMemory: Int,
+ workerCores: Int,
+ map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
- new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
- workerMemory, workerCores, hostToCount, rackToCount)
+ new YarnAllocationHandler(
+ conf,
+ resourceManager,
+ appAttemptId,
+ maxWorkers,
+ workerMemory,
+ workerCores,
+ hostToCount,
+ rackToCount)
}
// A simple method to copy the split info map.
- private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
+ private def generateNodeToWeight(
+ conf: Configuration,
+ input: collection.Map[String, collection.Set[SplitInfo]]) :
// host to count, rack to count
(Map[String, Int], Map[String, Int]) = {
@@ -536,7 +634,7 @@ object YarnAllocationHandler {
}
def lookupRack(conf: Configuration, host: String): String = {
- if (! hostToRack.contains(host)) populateRackInfo(conf, host)
+ if (!hostToRack.contains(host)) populateRackInfo(conf, host)
hostToRack.get(host)
}
@@ -559,10 +657,12 @@ object YarnAllocationHandler {
val rack = rackInfo.getNetworkLocation
hostToRack.put(hostname, rack)
if (! rackToHostSet.containsKey(rack)) {
- rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
+ rackToHostSet.putIfAbsent(rack,
+ Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
}
rackToHostSet.get(rack).add(hostname)
+ // TODO(harvey): Figure out this comment...
// Since RackResolver caches, we are disabling this for now ...
} /* else {
// right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index ca2f1e2565..2ba2366ead 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -18,13 +18,10 @@
package org.apache.spark.deploy.yarn
import org.apache.spark.deploy.SparkHadoopUtil
-import collection.mutable.HashMap
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
-import java.security.PrivilegedExceptionAction
/**
* Contains util methods to interact with Hadoop from spark.
@@ -40,7 +37,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
// add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
override def addCredentials(conf: JobConf) {
- val jobCreds = conf.getCredentials();
+ val jobCreds = conf.getCredentials()
jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
}
}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000..63a0449e5a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.util.Utils
+
+/**
+ *
+ * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ override def postStartHook() {
+
+ // The yarn application is running, but the worker might not yet ready
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(2000L)
+ logInfo("YarnClientClusterScheduler.postStartHook done")
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..b206780c78
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ var client: Client = null
+ var appId: ApplicationId = null
+
+ override def start() {
+ super.start()
+
+ val defalutWorkerCores = "2"
+ val defalutWorkerMemory = "512m"
+ val defaultWorkerNumber = "1"
+
+ val userJar = System.getenv("SPARK_YARN_APP_JAR")
+ var workerCores = System.getenv("SPARK_WORKER_CORES")
+ var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
+ var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
+
+ if (userJar == null)
+ throw new SparkException("env SPARK_YARN_APP_JAR is not set")
+
+ if (workerCores == null)
+ workerCores = defalutWorkerCores
+ if (workerMemory == null)
+ workerMemory = defalutWorkerMemory
+ if (workerNumber == null)
+ workerNumber = defaultWorkerNumber
+
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+
+ val argsArray = Array[String](
+ "--class", "notused",
+ "--jar", userJar,
+ "--args", hostport,
+ "--worker-memory", workerMemory,
+ "--worker-cores", workerCores,
+ "--num-workers", workerNumber,
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
+ )
+
+ val args = new ClientArguments(argsArray)
+ client = new Client(args)
+ appId = client.runApp()
+ waitForApp()
+ }
+
+ def waitForApp() {
+
+ // TODO : need a better way to find out whether the workers are ready or not
+ // maybe by resource usage report?
+ while(true) {
+ val report = client.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n"
+ )
+
+ // Ready to go, or already gone.
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.RUNNING) {
+ return
+ } else if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application already ended," +
+ "might be killed or not able to launch application master.")
+ }
+
+ Thread.sleep(1000)
+ }
+ }
+
+ override def stop() {
+ super.stop()
+ client.stop()
+ logInfo("Stoped")
+ }
+
+}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
new file mode 100644
index 0000000000..2941356bc5
--- /dev/null
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.scalatest.FunSuite
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito.when
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.permission.FsAction
+import org.apache.hadoop.yarn.api.records.LocalResource
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.api.records.LocalResourceType
+import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
+
+
+class ClientDistributedCacheManagerSuite extends FunSuite with MockitoSugar {
+
+ class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
+ override def getVisibility(conf: Configuration, uri: URI, statCache: Map[URI, FileStatus]):
+ LocalResourceVisibility = {
+ return LocalResourceVisibility.PRIVATE
+ }
+ }
+
+ test("test getFileStatus empty") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath() === null)
+ }
+
+ test("test getFileStatus cached") {
+ val distMgr = new ClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val uri = new URI("/tmp/testing")
+ val realFileStatus = new FileStatus(10, false, 1, 1024, 10, 10, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(new Path(uri))).thenReturn(new FileStatus())
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus](uri -> realFileStatus)
+ val stat = distMgr.getFileStatus(fs, uri, statCache)
+ assert(stat.getPath().toString() === "/tmp/testing")
+ }
+
+ test("test addResource") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 0)
+ assert(resource.getSize() === 0)
+ assert(resource.getType() === LocalResourceType.FILE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env("SPARK_YARN_CACHE_FILES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_FILE_SIZES") === "0")
+ assert(env("SPARK_YARN_CACHE_FILES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+
+ //add another one and verify both there and order correct
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing2"))
+ val destPath2 = new Path("file:///foo.invalid.com:8080/tmp/testing2")
+ when(fs.getFileStatus(destPath2)).thenReturn(realFileStatus)
+ distMgr.addResource(fs, conf, destPath2, localResources, LocalResourceType.FILE, "link2",
+ statCache, false)
+ val resource2 = localResources("link2")
+ assert(resource2.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource2.getResource()) === destPath2)
+ assert(resource2.getTimestamp() === 10)
+ assert(resource2.getSize() === 20)
+ assert(resource2.getType() === LocalResourceType.FILE)
+
+ val env2 = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env2)
+ val timestamps = env2("SPARK_YARN_CACHE_FILES_TIME_STAMPS").split(',')
+ val files = env2("SPARK_YARN_CACHE_FILES").split(',')
+ val sizes = env2("SPARK_YARN_CACHE_FILES_FILE_SIZES").split(',')
+ val visibilities = env2("SPARK_YARN_CACHE_FILES_VISIBILITIES") .split(',')
+ assert(files(0) === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(timestamps(0) === "0")
+ assert(sizes(0) === "0")
+ assert(visibilities(0) === LocalResourceVisibility.PRIVATE.name())
+
+ assert(files(1) === "file:/foo.invalid.com:8080/tmp/testing2#link2")
+ assert(timestamps(1) === "10")
+ assert(sizes(1) === "20")
+ assert(visibilities(1) === LocalResourceVisibility.PRIVATE.name())
+ }
+
+ test("test addResource link null") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ when(fs.getFileStatus(destPath)).thenReturn(new FileStatus())
+
+ intercept[Exception] {
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, null,
+ statCache, false)
+ }
+ assert(localResources.get("link") === None)
+ assert(localResources.size === 0)
+ }
+
+ test("test addResource appmaster only") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, true)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === None)
+ }
+
+ test("test addResource archive") {
+ val distMgr = new MockClientDistributedCacheManager()
+ val fs = mock[FileSystem]
+ val conf = new Configuration()
+ val destPath = new Path("file:///foo.invalid.com:8080/tmp/testing")
+ val localResources = HashMap[String, LocalResource]()
+ val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]()
+ val realFileStatus = new FileStatus(20, false, 1, 1024, 10, 30, null, "testOwner",
+ null, new Path("/tmp/testing"))
+ when(fs.getFileStatus(destPath)).thenReturn(realFileStatus)
+
+ distMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, "link",
+ statCache, false)
+ val resource = localResources("link")
+ assert(resource.getVisibility() === LocalResourceVisibility.PRIVATE)
+ assert(ConverterUtils.getPathFromYarnURL(resource.getResource()) === destPath)
+ assert(resource.getTimestamp() === 10)
+ assert(resource.getSize() === 20)
+ assert(resource.getType() === LocalResourceType.ARCHIVE)
+
+ val env = new HashMap[String, String]()
+
+ distMgr.setDistArchivesEnv(env)
+ assert(env("SPARK_YARN_CACHE_ARCHIVES") === "file:/foo.invalid.com:8080/tmp/testing#link")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_TIME_STAMPS") === "10")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_FILE_SIZES") === "20")
+ assert(env("SPARK_YARN_CACHE_ARCHIVES_VISIBILITIES") === LocalResourceVisibility.PRIVATE.name())
+
+ distMgr.setDistFilesEnv(env)
+ assert(env.get("SPARK_YARN_CACHE_FILES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_TIME_STAMPS") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_FILE_SIZES") === None)
+ assert(env.get("SPARK_YARN_CACHE_FILES_VISIBILITIES") === None)
+ }
+
+
+}